Skip to content

Commit 137bfb9

Browse files
authored
fix(sqla_factory): added an async context manager in SQLAASyncPersistence (#630)
1 parent 135bbc0 commit 137bfb9

File tree

2 files changed

+18
-12
lines changed

2 files changed

+18
-12
lines changed

polyfactory/factories/sqlalchemy_factory.py

+9-7
Original file line numberDiff line numberDiff line change
@@ -52,16 +52,18 @@ def __init__(self, session: AsyncSession) -> None:
5252
self.session = session
5353

5454
async def save(self, data: T) -> T:
55-
self.session.add(data)
56-
await self.session.commit()
57-
await self.session.refresh(data)
55+
async with self.session as session:
56+
session.add(data)
57+
await session.commit()
58+
await session.refresh(data)
5859
return data
5960

6061
async def save_many(self, data: list[T]) -> list[T]:
61-
self.session.add_all(data)
62-
await self.session.commit()
63-
for batch_item in data:
64-
await self.session.refresh(batch_item)
62+
async with self.session as session:
63+
session.add_all(data)
64+
await session.commit()
65+
for batch_item in data:
66+
await session.refresh(batch_item)
6567
return data
6668

6769

tests/sqlalchemy_factory/test_sqlalchemy_factory_common.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
func,
1919
inspect,
2020
orm,
21+
select,
2122
text,
2223
types,
2324
)
@@ -343,13 +344,15 @@ class Factory(SQLAlchemyFactory[AsyncModel]):
343344
__async_session__ = session_config(session)
344345
__model__ = AsyncModel
345346

346-
result = await Factory.create_async()
347-
assert inspect(result).persistent # type: ignore[union-attr]
347+
instance = await Factory.create_async()
348+
result = await session.scalar(select(AsyncModel).where(AsyncModel.id == instance.id))
349+
assert result
348350

349351
batch_result = await Factory.create_batch_async(size=2)
350352
assert len(batch_result) == 2
351353
for batch_item in batch_result:
352-
assert inspect(batch_item).persistent # type: ignore[union-attr]
354+
result = await session.scalar(select(AsyncModel).where(AsyncModel.id == batch_item.id))
355+
assert result
353356

354357

355358
@pytest.mark.parametrize(
@@ -392,8 +395,9 @@ class Factory(SQLAlchemyFactory[AsyncRefreshModel]):
392395
test_int = Ignore()
393396
test_bool = Ignore()
394397

395-
result = await Factory.create_async()
396-
assert inspect(result).persistent # type: ignore[union-attr]
398+
instance = await Factory.create_async()
399+
result = await session.scalar(select(AsyncRefreshModel).where(AsyncRefreshModel.id == instance.id))
400+
assert result
397401
assert result.test_datetime is not None
398402
assert isinstance(result.test_datetime, datetime)
399403
assert result.test_str == "test_str"

0 commit comments

Comments
 (0)