Skip to content

Commit

Permalink
fix: ensure ABC are not considered a factory type (#628)
Browse files Browse the repository at this point in the history
  • Loading branch information
adhtruong authored Jan 15, 2025
1 parent d2ef554 commit 135bbc0
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 5 deletions.
5 changes: 4 additions & 1 deletion polyfactory/factories/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import copy
import inspect
from abc import ABC, abstractmethod
from collections import Counter, abc, deque
from contextlib import suppress
Expand Down Expand Up @@ -404,7 +405,9 @@ def is_factory_type(cls, annotation: Any) -> bool:
:param annotation: A type annotation.
:returns: Boolean dictating whether the annotation is a factory type
"""
return any(factory.is_supported_type(annotation) for factory in BaseFactory._base_factories)
return not inspect.isabstract(annotation) and any(
factory.is_supported_type(annotation) for factory in BaseFactory._base_factories
)

@classmethod
def is_batch_factory_type(cls, annotation: Any) -> bool:
Expand Down
4 changes: 2 additions & 2 deletions tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class Person(BaseModel):
birthday: Union[datetime, date]


class PersonFactoryWithoutDefaults(ModelFactory):
class PersonFactoryWithoutDefaults(ModelFactory[Person]):
__model__ = Person


Expand All @@ -39,5 +39,5 @@ class PersonFactoryWithDefaults(PersonFactoryWithoutDefaults):
birthday = datetime(2021 - 33, 1, 1)


class PetFactory(ModelFactory):
class PetFactory(ModelFactory[Pet]):
__model__ = Pet
35 changes: 34 additions & 1 deletion tests/test_base_factories.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict
from typing import Any, Callable, Dict

import pytest

from pydantic.main import BaseModel

from polyfactory.exceptions import ParameterException
from polyfactory.factories import DataclassFactory
from polyfactory.factories.base import BaseFactory
from polyfactory.factories.pydantic_factory import ModelFactory
Expand Down Expand Up @@ -96,3 +98,34 @@ class Foo:
def test_create_factory_from_base_factory_without_providing_a_model_raises_error() -> None:
with pytest.raises(TypeError):
BaseFactory.create_factory()


def test_abstract_classes_are_ignored() -> None:
@dataclass
class Base(ABC):
@abstractmethod
def f(self) -> int: ...

@dataclass
class Concrete(Base):
def f(self) -> int:
return 1

@dataclass
class Model:
single: Base

class ModelFactory(DataclassFactory[Model]):
@classmethod
def get_provider_map(cls) -> Dict[type, Callable[[], Any]]:
return {
**super().get_provider_map(),
Base: Concrete,
}

result = ModelFactory.build()
assert isinstance(result, Model)
assert isinstance(result.single, Concrete)

with pytest.raises(ParameterException, match="Unsupported type: "):
DataclassFactory.create_factory(Model).build()
2 changes: 1 addition & 1 deletion tests/test_pydantic_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def test_factory_use_construct() -> None:
# factory should pass values without validation
invalid_age = "non_valid_age"
non_validated_pet = PetFactory.build(factory_use_construct=True, age=invalid_age)
assert non_validated_pet.age == invalid_age
assert non_validated_pet.age == invalid_age # type: ignore[comparison-overlap]

with pytest.raises(ValidationError):
PetFactory.build(age=invalid_age)
Expand Down

0 comments on commit 135bbc0

Please sign in to comment.