diff --git a/docs/guide/optimizer.md b/docs/guide/optimizer.md index 8114b01c..9693bb33 100644 --- a/docs/guide/optimizer.md +++ b/docs/guide/optimizer.md @@ -293,3 +293,193 @@ class OrderItem: `total` now will be properly optimized since it points to a `@model_property` decorated attribute, which contains the required information for optimizing it. + +## Optimizing polymorphic queries + +The optimizer has dedicated support for polymorphic queries, that is, fields which return an interface. +The optimizer will handle optimizing any subtypes of the interface as necessary. This is supported on top level queries +as well as relations between models. +See the following sections for how this interacts with your models. + +### Using Django Polymorphic + +If you are already using the [Django Polymorphic](https://django-polymorphic.readthedocs.io/en/stable/) library, +polymorphic queries work out of the box. + +```python title="models.py" +from django.db import models +from polymorphic.models import PolymorphicModel + +class Project(PolymorphicModel): + topic = models.CharField(max_length=255) + +class ResearchProject(Project): + supervisor = models.CharField(max_length=30) + +class ArtProject(Project): + artist = models.CharField(max_length=30) +``` + +```python title="types.py" +import strawberry +import strawberry_django +from . import models + + +@strawberry_django.interface(models.Project) +class ProjectType: + topic: strawberry.auto + + +@strawberry_django.type(models.ResearchProject) +class ResearchProjectType(ProjectType): + supervisor: strawberry.auto + + +@strawberry_django.type(models.ArtProject) +class ArtProjectType(ProjectType): + artist: strawberry.auto + + +@strawberry.type +class Query: + projects: list[ProjectType] = strawberry_django.field() +``` + +The `projects` field will return either ResearchProjectType or ArtProjectType, matching on whether it is a +ResearchProject or ArtProject. The optimizer will make sure to only select those fields from subclasses which are +requested in the GraphQL query in the same way that it does normally. + +> [!WARNING] +> The optimizer does not filter your QuerySet and Django will return +> all instances of your model, regardless of whether their type exists in your GraphQL schema or not. +> Make sure you have a corresponding type for every model subclass or add a `get_queryset` method to your +> GraphQL interface type to filter out unwanted subtypes. +> Otherwise you might receive an error like +> `Abstract type 'ProjectType' must resolve to an Object type at runtime for field 'Query.projects'.` + +### Using Model-Utils InheritanceManager + +Models using `InheritanceManager` from [django-model-utils](https://django-model-utils.readthedocs.io/en/latest/) +are also supported. + +```python title="models.py" +from django.db import models +from model_utils.managers import InheritanceManager + +class Project(models.Model): + topic = models.CharField(max_length=255) + + objects = InheritanceManager() + +class ResearchProject(Project): + supervisor = models.CharField(max_length=30) + +class ArtProject(Project): + artist = models.CharField(max_length=30) +``` + +```python title="types.py" +import strawberry +import strawberry_django +from . import models + + +@strawberry_django.interface(models.Project) +class ProjectType: + topic: strawberry.auto + + +@strawberry_django.type(models.ResearchProject) +class ResearchProjectType(ProjectType): + supervisor: strawberry.auto + + +@strawberry_django.type(models.ArtProject) +class ArtProjectType(ProjectType): + artist: strawberry.auto + + +@strawberry.type +class Query: + projects: list[ProjectType] = strawberry_django.field() +``` + +The `projects` field will return either ResearchProjectType or ArtProjectType, matching on whether it is a +ResearchProject or ArtProject. The optimizer automatically calls `select_subclasses`, passing in any subtypes present +in your schema. + +> [!WARNING] +> The optimizer does not filter your QuerySet and Django will return +> all instances of your model, regardless of whether their type exists in your GraphQL schema or not. +> Make sure you have a corresponding type for every model subclass or add a `get_queryset` method to your +> GraphQL interface type to filter out unwanted subtypes. +> Otherwise you might receive an error like +> `Abstract type 'ProjectType' must resolve to an Object type at runtime for field 'Query.projects'.` + +> [!NOTE] +> If you have polymorphic relations (as in: a field that points to a model with subclasses), you need to make sure +> the manager being used to look up the related model is an `InheritanceManager`. +> Strawberry Django uses the model's [base manager](https://docs.djangoproject.com/en/5.1/topics/db/managers/#base-managers) +> by default, which is different from the standard `objects`. +> Either change your base manager to also be an `InheritanceManager` or set Strawberry Django to use the default +> manager: `DjangoOptimizerExtension(prefetch_custom_queryset=True)`. + +### Custom polymorphic solution + +The optimizer also supports polymorphism even if your models are not polymorphic. +`resolve_type` in the GraphQL interface type is used to tell GraphQL the actual type that should be used. + +```python title="models.py" +from django.db import models + +class Project(models.Model): + topic = models.CharField(max_length=255) + supervisor = models.CharField(max_length=30) + artist = models.CharField(max_length=30) + +``` + +```python title="types.py" +import strawberry +import strawberry_django +from . import models + + +@strawberry_django.interface(models.Project) +class ProjectType: + topic: strawberry.auto + + @classmethod + def resolve_type(cls, value, info, parent_type) -> str: + if not isinstance(value, models.Project): + raise TypeError() + if value.artist: + return 'ArtProjectType' + if value.supervisor: + return 'ResearchProjectType' + raise TypeError() + + @classmethod + def get_queryset(cls, qs, info): + return qs + + +@strawberry_django.type(models.ResearchProject) +class ResearchProjectType(ProjectType): + supervisor: strawberry.auto + + +@strawberry_django.type(models.ArtProject) +class ArtProjectType(ProjectType): + artist: strawberry.auto + + +@strawberry.type +class Query: + projects: list[ProjectType] = strawberry_django.field() +``` + +> [!WARNING] +> Make sure to add `get_queryset` to your interface type, to force the optimizer to use +> `prefetch_related`, otherwise this technique will not work for relation fields. diff --git a/poetry.lock b/poetry.lock index 737c0aaf..c2a1d39a 100644 --- a/poetry.lock +++ b/poetry.lock @@ -197,6 +197,21 @@ files = [ [package.dependencies] Django = ">=2.2" +[[package]] +name = "django-model-utils" +version = "5.0.0" +description = "Django model mixins and utilities" +optional = false +python-versions = ">=3.8" +groups = ["dev"] +files = [ + {file = "django_model_utils-5.0.0-py3-none-any.whl", hash = "sha256:fec78e6c323d565a221f7c4edc703f4567d7bb1caeafe1acd16a80c5ff82056b"}, + {file = "django_model_utils-5.0.0.tar.gz", hash = "sha256:041cdd6230d2fbf6cd943e1969318bce762272077f4ecd333ab2263924b4e5eb"}, +] + +[package.dependencies] +Django = ">=3.2" + [[package]] name = "django-polymorphic" version = "3.1.0" @@ -931,4 +946,4 @@ enum = ["django-choices-field"] [metadata] lock-version = "2.1" python-versions = ">=3.9,<4.0" -content-hash = "20e6563b36f4043c2b2fb1f9f6bc377f9864127a2f0e65e4871636e18b5cf05c" +content-hash = "18a410e1d60ae7c898b8fa134123f6f0c0a3f81d59d335631aba5b8d797badea" diff --git a/pyproject.toml b/pyproject.toml index 134bf07d..524ccb3b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,6 +59,7 @@ setuptools = "^77.0.3" psycopg2 = "^2.9.9" psycopg2-binary = "^2.9.9" django-tree-queries = "^0.19.0" +django-model-utils = "^5.0.0" [tool.poetry.extras] debug-toolbar = ["django-debug-toolbar"] diff --git a/strawberry_django/optimizer.py b/strawberry_django/optimizer.py index 830171dc..cba1c712 100644 --- a/strawberry_django/optimizer.py +++ b/strawberry_django/optimizer.py @@ -5,7 +5,7 @@ import copy import dataclasses import itertools -from collections import Counter, defaultdict +from collections import Counter from collections.abc import Callable from typing import ( TYPE_CHECKING, @@ -59,7 +59,11 @@ PrefetchInspector, get_model_field, get_model_fields, + get_possible_concrete_types, get_possible_type_definitions, + is_inheritance_manager, + is_inheritance_qs, + is_polymorphic_model, ) from .utils.typing import ( AnnotateCallable, @@ -95,12 +99,6 @@ _sentinel = object() _annotate_placeholder = "__annotated_placeholder__" -_interfaces: defaultdict[ - Schema, - dict[StrawberryObjectDefinition, list[StrawberryObjectDefinition]], -] = defaultdict( - dict, -) @dataclasses.dataclass @@ -706,6 +704,27 @@ def _get_hints_from_model_property( return store +def _must_use_prefetch_related( + config: OptimizerConfig, + field: StrawberryField, + model_field: models.ForeignKey | OneToOneRel, +) -> bool: + f_type = _get_django_type(field) + + # - If the field has a get_queryset method, use Prefetch so it will be respected + # - If the model is using django-polymorphic, + # use Prefetch so its custom queryset will be used, returning polymorphic models + return ( + (f_type and hasattr(f_type, "get_queryset")) + or is_polymorphic_model(model_field.related_model) + or is_inheritance_manager( + model_field.related_model._default_manager + if config.prefetch_custom_queryset + else model_field.related_model._base_manager # type: ignore + ) + ) + + def _get_hints_from_django_foreign_key( field: StrawberryField, field_definition: GraphQLObjectType, @@ -721,13 +740,9 @@ def _get_hints_from_django_foreign_key( cache: dict[type[models.Model], list[tuple[int, OptimizerStore]]], level: int = 0, ) -> OptimizerStore: - f_type = _get_django_type(field) - if f_type and hasattr(f_type, "get_queryset"): - # If the field has a get_queryset method, change strategy to Prefetch - # so it will be respected + if _must_use_prefetch_related(config, field, model_field): store = _get_hints_from_django_relation( field, - field_definition=field_definition, field_selection=field_selection, model_field=model_field, model_fieldname=model_fieldname, @@ -776,7 +791,6 @@ def _get_hints_from_django_foreign_key( def _get_hints_from_django_relation( field: StrawberryField, - field_definition: GraphQLObjectType, field_selection: FieldNode, model_field: ( models.ManyToManyField @@ -820,16 +834,34 @@ def _get_hints_from_django_relation( remote_field = model_field.remote_field remote_model = remote_field.model - field_store = _get_model_hints( - remote_model, - schema, - f_types[0], - parent_type=field_definition, - info=field_info, - config=config, - cache=cache, - level=level + 1, - ) + field_store = None + f_type = f_types[0] + subclasses = [] + for concrete_field_type in get_possible_concrete_types( + remote_model, schema, f_type + ): + django_definition = get_django_definition(concrete_field_type.origin) + if ( + django_definition + and django_definition.model != remote_model + and not django_definition.model._meta.abstract + and issubclass(django_definition.model, remote_model) + ): + subclasses.append(django_definition.model) + concrete_store = _get_model_hints( + remote_model, + schema, + concrete_field_type, + parent_type=_get_gql_definition(schema, concrete_field_type), + info=field_info, + config=config, + cache=cache, + level=level + 1, + ) + if concrete_store is not None: + field_store = ( + concrete_store if field_store is None else field_store | concrete_store + ) if field_store is None: return store @@ -877,6 +909,8 @@ def _get_hints_from_django_relation( info=field_info, related_field_id=related_field_id, ) + if is_inheritance_qs(base_qs): + base_qs = base_qs.select_subclasses(*subclasses) field_qs = field_store.apply(base_qs, info=field_info, config=config) field_prefetch = Prefetch(path, queryset=field_qs) field_prefetch._optimizer_sentinel = _sentinel # type: ignore @@ -956,7 +990,6 @@ def _get_hints_from_django_field( elif isinstance(model_field, relation_fields): store = _get_hints_from_django_relation( field, - field_definition=field_definition, field_selection=field_selection, model_field=model_field, model_fieldname=model_fieldname, @@ -985,6 +1018,7 @@ def _get_model_hints( prefix: str = "", cache: dict[type[models.Model], list[tuple[int, OptimizerStore]]] | None = None, level: int = 0, + subclass_collection: set[type[models.Model]] | None = None, ) -> OptimizerStore | None: cache = cache or {} @@ -1000,6 +1034,7 @@ def _get_model_hints( prefix=prefix, cache=cache, level=level, + subclass_collection=subclass_collection, ) # In case this is a Paginated field, the selected fields are inside results selection @@ -1014,17 +1049,56 @@ def _get_model_hints( prefix=prefix, cache=cache, level=level, + subclass_collection=subclass_collection, ) store = OptimizerStore() config = config or OptimizerConfig() dj_definition = get_django_definition(object_definition.origin) - if ( - dj_definition is None - or not issubclass(model, dj_definition.model) - or dj_definition.disable_optimization - ): + if dj_definition is None or dj_definition.disable_optimization: + return None + + if not issubclass(model, dj_definition.model): + # If this is a PolymorphicModel, also try to optimize fields in subclasses + # of the current model. + if not dj_definition.model._meta.abstract and issubclass( + dj_definition.model, model + ): + if subclass_collection is not None: + subclass_collection.add(dj_definition.model) + if is_polymorphic_model(model): + # These must be prefixed with app_label__ModelName___ (note three underscores) + # This is a special syntax for django-polymorphic: + # https://django-polymorphic.readthedocs.io/en/stable/advanced.html#polymorphic-filtering-for-fields-in-inherited-classes + return _get_model_hints( + dj_definition.model, + schema, + object_definition, + parent_type=parent_type, + info=info, + config=config, + prefix=f"{prefix}{dj_definition.model._meta.app_label}__{dj_definition.model._meta.model_name}___", + ) + if is_inheritance_manager(model._default_manager) and ( + path_from_parent := dj_definition.model._meta.get_path_from_parent( + model + ) + ): + prefix = LOOKUP_SEP.join( + p.join_field.get_accessor_name() for p in path_from_parent + ) + prefix += LOOKUP_SEP + return _get_model_hints( + dj_definition.model, + schema, + object_definition, + parent_type=parent_type, + info=info, + config=config, + prefix=prefix, + ) + return None dj_type_store = getattr(dj_definition, "store", None) @@ -1034,7 +1108,11 @@ def _get_model_hints( # Make sure that the model's pk is always selected when using only pk = model._meta.pk if pk is not None: - store.only.append(pk.attname) + store.only.append(prefix + pk.attname) + + # If this is a polymorphic Model, make sure to select its content type + if is_polymorphic_model(model): + store.only.extend(prefix + f for f in model.polymorphic_internal_model_fields) selections = [ field_data @@ -1148,6 +1226,7 @@ def _get_model_hints_from_connection( prefix: str = "", cache: dict[type[models.Model], list[tuple[int, OptimizerStore]]] | None = None, level: int = 0, + subclass_collection: set[type[models.Model]] | None = None, ) -> OptimizerStore | None: store = None @@ -1187,29 +1266,34 @@ def _get_model_hints_from_connection( if node.name.value != "node": continue - n_gql_definition = _get_gql_definition(schema, n_definition) - assert isinstance( - n_gql_definition, - (GraphQLObjectType, GraphQLInterfaceType), - ) - n_info = _generate_selection_resolve_info( - info, - nodes, - n_gql_definition, - e_gql_definition, - ) - - store = _get_model_hints( - model=model, - schema=schema, - object_definition=n_definition, - parent_type=n_gql_definition, - info=n_info, - config=config, - prefix=prefix, - cache=cache, - level=level, - ) + for concrete_n_type in get_possible_concrete_types( + model, schema, n_definition + ): + n_gql_definition = _get_gql_definition(schema, concrete_n_type) + assert isinstance( + n_gql_definition, + (GraphQLObjectType, GraphQLInterfaceType), + ) + n_info = _generate_selection_resolve_info( + info, + nodes, + n_gql_definition, + e_gql_definition, + ) + concrete_store = _get_model_hints( + model=model, + schema=schema, + object_definition=concrete_n_type, + parent_type=n_gql_definition, + info=n_info, + config=config, + prefix=prefix, + cache=cache, + level=level, + subclass_collection=subclass_collection, + ) + if concrete_store is not None: + store = concrete_store if store is None else store | concrete_store return store @@ -1225,6 +1309,7 @@ def _get_model_hints_from_paginated( prefix: str = "", cache: dict[type[models.Model], list[tuple[int, OptimizerStore]]] | None = None, level: int = 0, + subclass_collection: set[type[models.Model]] | None = None, ) -> OptimizerStore | None: store = None @@ -1234,35 +1319,41 @@ def _get_model_hints_from_paginated( n_type = n_type.resolve_type() n_definition = get_object_definition(n_type, strict=True) - n_gql_definition = _get_gql_definition( - schema, - get_object_definition(n_type, strict=True), - ) - assert isinstance(n_gql_definition, (GraphQLObjectType, GraphQLInterfaceType)) for selections in _get_selections(info, parent_type).values(): selection = selections[0] if selection.name.value != "results": continue - n_info = _generate_selection_resolve_info( - info, - selections, - n_gql_definition, - n_gql_definition, - ) + for concrete_n_type in get_possible_concrete_types(model, schema, n_definition): + n_gql_definition = _get_gql_definition( + schema, + concrete_n_type, + ) + assert isinstance( + n_gql_definition, (GraphQLObjectType, GraphQLInterfaceType) + ) + n_info = _generate_selection_resolve_info( + info, + selections, + n_gql_definition, + n_gql_definition, + ) - store = _get_model_hints( - model=model, - schema=schema, - object_definition=n_definition, - parent_type=n_gql_definition, - info=n_info, - config=config, - prefix=prefix, - cache=cache, - level=level, - ) + concrete_store = _get_model_hints( + model=model, + schema=schema, + object_definition=concrete_n_type, + parent_type=n_gql_definition, + info=n_info, + config=config, + prefix=prefix, + cache=cache, + level=level, + subclass_collection=subclass_collection, + ) + if concrete_store is not None: + store = concrete_store if store is None else store | concrete_store return store @@ -1330,41 +1421,28 @@ def optimize( if strawberry_type is None: return qs - for object_definition in get_possible_type_definitions(strawberry_type): - if object_definition.is_interface: - interface_definitions = _interfaces[schema].get(object_definition) - if interface_definitions is None: - interface_definitions = [] - for t in schema.schema_converter.type_map.values(): - t_definition = t.definition - if isinstance( - t_definition, StrawberryObjectDefinition - ) and issubclass(t_definition.origin, object_definition.origin): - interface_definitions.append(t_definition) - _interfaces[schema][object_definition] = interface_definitions - - object_definitions = [] - for interface_definition in interface_definitions: - dj_definition = get_django_definition(interface_definition.origin) - if dj_definition and issubclass(qs.model, dj_definition.model): - object_definitions.append(interface_definition) - else: - object_definitions = [object_definition] - - for inner_object_definition in object_definitions: - parent_type = _get_gql_definition(schema, inner_object_definition) - new_store = _get_model_hints( - qs.model, - schema, - inner_object_definition, - parent_type=parent_type, - info=info, - config=config, - ) - if new_store is not None: - store |= new_store + inheritance_qs = is_inheritance_qs(qs) + subclasses = set() if inheritance_qs else None + + for inner_object_definition in get_possible_concrete_types( + qs.model, schema, strawberry_type + ): + parent_type = _get_gql_definition(schema, inner_object_definition) + new_store = _get_model_hints( + qs.model, + schema, + inner_object_definition, + parent_type=parent_type, + info=info, + config=config, + subclass_collection=subclasses, + ) + if new_store is not None: + store |= new_store if store: + if inheritance_qs and subclasses: + qs = qs.select_subclasses(*subclasses) qs = store.apply(qs, info=info, config=config) qs_config = get_queryset_config(qs) qs_config.optimized = True diff --git a/strawberry_django/utils/inspect.py b/strawberry_django/utils/inspect.py index 1f72074e..b4421216 100644 --- a/strawberry_django/utils/inspect.py +++ b/strawberry_django/utils/inspect.py @@ -3,13 +3,17 @@ import dataclasses import functools import itertools +from collections import defaultdict +from collections.abc import Iterable from typing import ( TYPE_CHECKING, + Any, cast, ) from django.db.models.query import Prefetch, QuerySet from django.db.models.sql.where import WhereNode +from strawberry import Schema from strawberry.types import has_object_definition from strawberry.types.base import ( StrawberryContainer, @@ -20,11 +24,12 @@ from strawberry.types.lazy_type import LazyType from strawberry.types.union import StrawberryUnion from strawberry.utils.str_converters import to_camel_case -from typing_extensions import assert_never +from typing_extensions import TypeIs, assert_never from strawberry_django.fields.types import resolve_model_field_name from .pyutils import DictTree, dicttree_insersection_differs, dicttree_merge +from .typing import get_django_definition if TYPE_CHECKING: from collections.abc import Generator, Iterable @@ -34,6 +39,11 @@ from django.db.models.fields import Field from django.db.models.fields.reverse_related import ForeignObjectRel from django.db.models.sql.query import Query + from model_utils.managers import ( + InheritanceManagerMixin, + InheritanceQuerySetMixin, + ) + from polymorphic.models import PolymorphicModel @functools.lru_cache @@ -143,6 +153,95 @@ def get_possible_type_definitions( yield t.__strawberry_definition__ +try: + # Can't import PolymorphicModel, because it requires Django Apps to be ready + # Import polymorphic instead to check for its existence + import polymorphic # noqa: F401 + + def is_polymorphic_model(v: type) -> TypeIs[type[PolymorphicModel]]: + return getattr(v, "polymorphic_model_marker", False) is True + +except ImportError: + + def is_polymorphic_model(v: type) -> TypeIs[type[PolymorphicModel]]: + return False + + +try: + from model_utils.managers import InheritanceManagerMixin, InheritanceQuerySetMixin + + def is_inheritance_manager( + v: Any, + ) -> TypeIs[InheritanceManagerMixin]: + return isinstance(v, InheritanceManagerMixin) + + def is_inheritance_qs( + v: Any, + ) -> TypeIs[InheritanceQuerySetMixin]: + return isinstance(v, InheritanceQuerySetMixin) + +except ImportError: + + def is_inheritance_manager( + v: Any, + ) -> TypeIs[InheritanceManagerMixin]: + return False + + def is_inheritance_qs( + v: Any, + ) -> TypeIs[InheritanceQuerySetMixin]: + return False + + +def _can_optimize_subtypes(model: type[models.Model]) -> bool: + return is_polymorphic_model(model) or is_inheritance_manager(model._default_manager) + + +_interfaces: defaultdict[ + Schema, + dict[StrawberryObjectDefinition, list[StrawberryObjectDefinition]], +] = defaultdict(dict) + + +def get_possible_concrete_types( + model: type[models.Model], + schema: Schema, + strawberry_type: StrawberryObjectDefinition | StrawberryType, +) -> Iterable[StrawberryObjectDefinition]: + """Return the object definitions the optimizer should look at when optimizing a model. + + Returns any object definitions attached to either the model or one of its supertypes. + + If the model is one that supports polymorphism, by returning subtypes from its queryset, subtypes are also + looked at. Currently, this is supported for django-polymorphic and django-model-utils InheritanceManager. + """ + for object_definition in get_possible_type_definitions(strawberry_type): + if not object_definition.is_interface: + yield object_definition + continue + interface_definitions = _interfaces[schema].get(object_definition) + if interface_definitions is None: + interface_definitions = [] + for t in schema.schema_converter.type_map.values(): + t_definition = t.definition + if isinstance(t_definition, StrawberryObjectDefinition) and issubclass( + t_definition.origin, object_definition.origin + ): + interface_definitions.append(t_definition) + _interfaces[schema][object_definition] = interface_definitions + + for interface_definition in interface_definitions: + dj_definition = get_django_definition(interface_definition.origin) + if dj_definition and ( + issubclass(model, dj_definition.model) + or ( + _can_optimize_subtypes(model) + and issubclass(dj_definition.model, model) + ) + ): + yield interface_definition + + @dataclasses.dataclass(eq=True) class PrefetchInspector: """Prefetch hints.""" diff --git a/tests/django_settings.py b/tests/django_settings.py index 2c7bfe56..3387bb2f 100644 --- a/tests/django_settings.py +++ b/tests/django_settings.py @@ -114,5 +114,7 @@ "tests", "tests.projects", "tests.polymorphism", + "tests.polymorphism_custom", + "tests.polymorphism_inheritancemanager", ], ) diff --git a/tests/node_polymorphism/test_optimizer.py b/tests/node_polymorphism/test_optimizer.py index d53806ac..e678142b 100644 --- a/tests/node_polymorphism/test_optimizer.py +++ b/tests/node_polymorphism/test_optimizer.py @@ -1,5 +1,7 @@ import pytest +from tests.utils import assert_num_queries + from .models import ArtProject, ResearchProject from .schema import schema @@ -28,7 +30,9 @@ def test_polymorphic_interface_query(): } """ - result = schema.execute_sync(query) + # ContentType, base table, two subtables = 4 queries + with assert_num_queries(4): + result = schema.execute_sync(query) assert not result.errors assert result.data == { "projects": { diff --git a/tests/polymorphism/models.py b/tests/polymorphism/models.py index 6d6b1378..8f6d380b 100644 --- a/tests/polymorphism/models.py +++ b/tests/polymorphism/models.py @@ -2,13 +2,57 @@ from polymorphic.models import PolymorphicModel +class Company(models.Model): + name = models.CharField(max_length=100) + main_project = models.ForeignKey("Project", on_delete=models.CASCADE, null=True) + + class Meta: + ordering = ("name",) + + class Project(PolymorphicModel): + company = models.ForeignKey( + Company, + null=True, + blank=True, + on_delete=models.CASCADE, + related_name="projects", + ) topic = models.CharField(max_length=30) class ArtProject(Project): artist = models.CharField(max_length=30) + art_style = models.CharField(max_length=30) class ResearchProject(Project): supervisor = models.CharField(max_length=30) + research_notes = models.TextField() + + +class TechnicalProject(Project): + timeline = models.CharField(max_length=30) + + class Meta: # pyright: ignore [reportIncompatibleVariableOverride] + abstract = True + + +class SoftwareProject(TechnicalProject): + repository = models.CharField(max_length=255) + + +class EngineeringProject(TechnicalProject): + lead_engineer = models.CharField(max_length=255) + + +class AppProject(TechnicalProject): + repository = models.CharField(max_length=255) + + +class AndroidProject(AppProject): + android_version = models.CharField(max_length=15) + + +class IOSProject(AppProject): + ios_version = models.CharField(max_length=15) diff --git a/tests/polymorphism/schema.py b/tests/polymorphism/schema.py index 43cab9eb..4d6dde0d 100644 --- a/tests/polymorphism/schema.py +++ b/tests/polymorphism/schema.py @@ -1,9 +1,23 @@ +from typing import Optional + import strawberry import strawberry_django from strawberry_django.optimizer import DjangoOptimizerExtension +from strawberry_django.pagination import OffsetPaginated -from .models import ArtProject, Project, ResearchProject +from .models import ( + AndroidProject, + AppProject, + ArtProject, + Company, + EngineeringProject, + IOSProject, + Project, + ResearchProject, + SoftwareProject, + TechnicalProject, +) @strawberry_django.interface(Project) @@ -21,13 +35,63 @@ class ResearchProjectType(ProjectType): supervisor: strawberry.auto +@strawberry_django.interface(TechnicalProject) +class TechnicalProjectType(ProjectType): + timeline: strawberry.auto + + +@strawberry_django.type(SoftwareProject) +class SoftwareProjectType(TechnicalProjectType): + repository: strawberry.auto + + +@strawberry_django.type(EngineeringProject) +class EngineeringProjectType(TechnicalProjectType): + lead_engineer: strawberry.auto + + +@strawberry_django.interface(AppProject) +class AppProjectType(TechnicalProjectType): + repository: strawberry.auto + + +@strawberry_django.type(AndroidProject) +class AndroidProjectType(AppProjectType): + android_version: strawberry.auto + + +@strawberry_django.type(IOSProject) +class IOSProjectType(AppProjectType): + ios_version: strawberry.auto + + +@strawberry_django.type(Company) +class CompanyType: + name: strawberry.auto + projects: list[ProjectType] + main_project: Optional[ProjectType] + + @strawberry.type class Query: + companies: list[CompanyType] = strawberry_django.field() projects: list[ProjectType] = strawberry_django.field() + projects_paginated: list[ProjectType] = strawberry_django.field(pagination=True) + projects_offset_paginated: OffsetPaginated[ProjectType] = ( + strawberry_django.offset_paginated() + ) schema = strawberry.Schema( query=Query, - types=[ArtProjectType, ResearchProjectType], + types=[ + ArtProjectType, + ResearchProjectType, + TechnicalProjectType, + EngineeringProjectType, + SoftwareProjectType, + AndroidProjectType, + IOSProjectType, + ], extensions=[DjangoOptimizerExtension], ) diff --git a/tests/polymorphism/test_optimizer.py b/tests/polymorphism/test_optimizer.py index 8e16bc7c..5f00fc8f 100644 --- a/tests/polymorphism/test_optimizer.py +++ b/tests/polymorphism/test_optimizer.py @@ -1,6 +1,18 @@ import pytest +from django.db import DEFAULT_DB_ALIAS, connections +from django.test.utils import CaptureQueriesContext -from .models import ArtProject, ResearchProject +from tests.utils import assert_num_queries + +from .models import ( + AndroidProject, + ArtProject, + Company, + EngineeringProject, + IOSProject, + ResearchProject, + SoftwareProject, +) from .schema import schema @@ -24,7 +36,9 @@ def test_polymorphic_interface_query(): } """ - result = schema.execute_sync(query) + # ContentType, base table, two subtables = 4 queries + with assert_num_queries(4): + result = schema.execute_sync(query) assert not result.errors assert result.data == { "projects": [ @@ -36,3 +50,395 @@ def test_polymorphic_interface_query(): }, ] } + + +@pytest.mark.django_db(transaction=True) +def test_polymorphic_query_abstract_model(): + ap = ArtProject.objects.create(topic="Art", artist="Artist") + sp = SoftwareProject.objects.create( + topic="Software", repository="https://example.com", timeline="3 months" + ) + ep = EngineeringProject.objects.create( + topic="Engineering", lead_engineer="Elara Voss", timeline="6 years" + ) + + query = """\ + query { + projects { + __typename + topic + ... on ArtProjectType { + artist + } + ...on TechnicalProjectType { + timeline + } + ... on SoftwareProjectType { + repository + } + ...on EngineeringProjectType { + leadEngineer + } + } + } + """ + + with assert_num_queries(5): + result = schema.execute_sync(query) + assert not result.errors + assert result.data == { + "projects": [ + {"__typename": "ArtProjectType", "topic": ap.topic, "artist": ap.artist}, + { + "__typename": "SoftwareProjectType", + "topic": sp.topic, + "repository": sp.repository, + "timeline": sp.timeline, + }, + { + "__typename": "EngineeringProjectType", + "topic": ep.topic, + "leadEngineer": ep.lead_engineer, + "timeline": ep.timeline, + }, + ] + } + + +@pytest.mark.django_db(transaction=True) +def test_polymorphic_query_multiple_inheritance_levels(): + app1 = AndroidProject.objects.create( + topic="Software", + repository="https://example.com/android", + timeline="3 months", + android_version="14", + ) + app2 = IOSProject.objects.create( + topic="Software", + repository="https://example.com/ios", + timeline="5 months", + ios_version="16", + ) + ep = EngineeringProject.objects.create( + topic="Engineering", lead_engineer="Elara Voss", timeline="6 years" + ) + + query = """\ + query { + projects { + __typename + topic + ...on TechnicalProjectType { + timeline + } + ...on AppProjectType { + repository + } + ...on AndroidProjectType { + androidVersion + } + ...on IOSProjectType { + iosVersion + } + ...on EngineeringProjectType { + leadEngineer + } + } + } + """ + + # Project Table, Content Type, AndroidProject, IOSProject, EngineeringProject = 5 + with assert_num_queries(5): + result = schema.execute_sync(query) + assert not result.errors + assert result.data == { + "projects": [ + { + "__typename": "AndroidProjectType", + "topic": app1.topic, + "repository": app1.repository, + "timeline": app1.timeline, + "androidVersion": app1.android_version, + }, + { + "__typename": "IOSProjectType", + "topic": app2.topic, + "repository": app2.repository, + "timeline": app2.timeline, + "iosVersion": app2.ios_version, + }, + { + "__typename": "EngineeringProjectType", + "topic": ep.topic, + "leadEngineer": ep.lead_engineer, + "timeline": ep.timeline, + }, + ] + } + + +@pytest.mark.django_db(transaction=True) +def test_polymorphic_query_abstract_model_on_field(): + ep = EngineeringProject.objects.create( + topic="Engineering", lead_engineer="Elara Voss", timeline="6 years" + ) + company = Company.objects.create(name="Company", main_project=ep) + + query = """\ + query { + companies { + name + mainProject { + __typename + topic + ...on TechnicalProjectType { + timeline + } + ...on EngineeringProjectType { + leadEngineer + } + } + } + } + """ + + with assert_num_queries(4): + result = schema.execute_sync(query) + assert not result.errors + assert result.data == { + "companies": [ + { + "name": company.name, + "mainProject": { + "__typename": "EngineeringProjectType", + "topic": ep.topic, + "leadEngineer": ep.lead_engineer, + "timeline": ep.timeline, + }, + } + ] + } + + +@pytest.mark.django_db(transaction=True) +def test_polymorphic_query_optimization_working(): + ap = ArtProject.objects.create(topic="Art", artist="Artist") + rp = ResearchProject.objects.create(topic="Research", supervisor="Supervisor") + + query = """\ + query { + projects { + __typename + topic + ... on ArtProjectType { + artist + } + ... on ResearchProjectType { + supervisor + } + } + } + """ + + with CaptureQueriesContext(connection=connections[DEFAULT_DB_ALIAS]) as ctx: + result = schema.execute_sync(query) + # validate that we're not selecting extra fields + assert any("artist" in q["sql"] for q in ctx.captured_queries) + assert not any("research_notes" in q["sql"] for q in ctx.captured_queries) + assert not any("art_style" in q["sql"] for q in ctx.captured_queries) + assert not result.errors + assert result.data == { + "projects": [ + {"__typename": "ArtProjectType", "topic": ap.topic, "artist": ap.artist}, + { + "__typename": "ResearchProjectType", + "topic": rp.topic, + "supervisor": rp.supervisor, + }, + ] + } + + +@pytest.mark.django_db(transaction=True) +def test_polymorphic_paginated_query(): + ap = ArtProject.objects.create(topic="Art", artist="Artist") + rp = ResearchProject.objects.create(topic="Research", supervisor="Supervisor") + + query = """\ + query { + projectsPaginated { + __typename + topic + ... on ArtProjectType { + artist + } + ... on ResearchProjectType { + supervisor + } + } + } + """ + + # ContentType, base table, two subtables = 4 queries + with assert_num_queries(4): + result = schema.execute_sync(query) + assert not result.errors + assert result.data == { + "projectsPaginated": [ + {"__typename": "ArtProjectType", "topic": ap.topic, "artist": ap.artist}, + { + "__typename": "ResearchProjectType", + "topic": rp.topic, + "supervisor": rp.supervisor, + }, + ] + } + + +@pytest.mark.django_db(transaction=True) +def test_polymorphic_offset_paginated_query(): + ap = ArtProject.objects.create(topic="Art", artist="Artist") + rp = ResearchProject.objects.create(topic="Research", supervisor="Supervisor") + + query = """\ + query { + projectsOffsetPaginated { + totalCount + results { + __typename + topic + ... on ArtProjectType { + artist + } + ... on ResearchProjectType { + supervisor + } + } + } + } + """ + + # ContentType, base table, two subtables = 4 queries + 1 query for total count + with assert_num_queries(5): + result = schema.execute_sync(query) + assert not result.errors + assert result.data == { + "projectsOffsetPaginated": { + "totalCount": 2, + "results": [ + { + "__typename": "ArtProjectType", + "topic": ap.topic, + "artist": ap.artist, + }, + { + "__typename": "ResearchProjectType", + "topic": rp.topic, + "supervisor": rp.supervisor, + }, + ], + } + } + + +@pytest.mark.django_db(transaction=True) +def test_polymorphic_relation(): + ap = ArtProject.objects.create(topic="Art", artist="Artist") + art_company = Company.objects.create(name="ArtCompany", main_project=ap) + + rp = ResearchProject.objects.create(topic="Research", supervisor="Supervisor") + research_company = Company.objects.create(name="ResearchCompany", main_project=rp) + + query = """\ + query { + companies { + name + mainProject { + __typename + topic + ... on ArtProjectType { + artist + } + ... on ResearchProjectType { + supervisor + } + } + } + } + """ + + # Company, ContentType, base table, two subtables = 5 queries + with assert_num_queries(5): + result = schema.execute_sync(query) + assert not result.errors + assert result.data == { + "companies": [ + { + "name": art_company.name, + "mainProject": { + "__typename": "ArtProjectType", + "topic": ap.topic, + "artist": ap.artist, + }, + }, + { + "name": research_company.name, + "mainProject": { + "__typename": "ResearchProjectType", + "topic": rp.topic, + "supervisor": rp.supervisor, + }, + }, + ] + } + + +@pytest.mark.django_db(transaction=True) +def test_polymorphic_nested_list(): + company = Company.objects.create(name="Company") + ap = ArtProject.objects.create(company=company, topic="Art", artist="Artist") + rp = ResearchProject.objects.create( + company=company, topic="Research", supervisor="Supervisor" + ) + + query = """\ + query { + companies { + name + projects { + __typename + topic + ... on ArtProjectType { + artist + } + ... on ResearchProjectType { + supervisor + } + } + } + } + """ + + # Company, ContentType, base table, two subtables = 5 queries + with assert_num_queries(5): + result = schema.execute_sync(query) + assert not result.errors + assert result.data == { + "companies": [ + { + "name": "Company", + "projects": [ + { + "__typename": "ArtProjectType", + "topic": ap.topic, + "artist": ap.artist, + }, + { + "__typename": "ResearchProjectType", + "topic": rp.topic, + "supervisor": rp.supervisor, + }, + ], + } + ] + } diff --git a/tests/polymorphism_custom/__init__.py b/tests/polymorphism_custom/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/polymorphism_custom/models.py b/tests/polymorphism_custom/models.py new file mode 100644 index 00000000..fe666f7e --- /dev/null +++ b/tests/polymorphism_custom/models.py @@ -0,0 +1,34 @@ +from django.db import models + + +class Company(models.Model): + name = models.CharField(max_length=100) + main_project = models.ForeignKey( + "Project", null=True, blank=True, on_delete=models.CASCADE + ) + + class Meta: + ordering = ("name",) + + +class Project(models.Model): + company = models.ForeignKey( + Company, + null=True, + blank=True, + on_delete=models.CASCADE, + related_name="projects", + ) + topic = models.CharField(max_length=30) + artist = models.CharField(max_length=30, blank=True) + supervisor = models.CharField(max_length=30, blank=True) + research_notes = models.TextField(blank=True) + + class Meta: + constraints = ( + models.CheckConstraint( + check=(models.Q(artist="") | models.Q(supervisor="")) + & (~models.Q(topic="") | ~models.Q(topic="")), + name="artist_xor_supervisor", + ), + ) diff --git a/tests/polymorphism_custom/schema.py b/tests/polymorphism_custom/schema.py new file mode 100644 index 00000000..f4229cb5 --- /dev/null +++ b/tests/polymorphism_custom/schema.py @@ -0,0 +1,71 @@ +from typing import Any, Optional + +import strawberry +from graphql import GraphQLAbstractType, GraphQLResolveInfo +from strawberry import Info +from strawberry.relay import Node + +import strawberry_django +from strawberry_django.optimizer import DjangoOptimizerExtension +from strawberry_django.pagination import OffsetPaginated +from strawberry_django.relay import ListConnectionWithTotalCount + +from .models import Company, Project + + +@strawberry_django.interface(Project) +class ProjectType(Node): + topic: strawberry.auto + + @classmethod + def resolve_type( + cls, value: Any, info: GraphQLResolveInfo, parent_type: GraphQLAbstractType + ) -> str: + if not isinstance(value, Project): + raise TypeError + if value.artist: + return "ArtProjectType" + if value.supervisor: + return "ResearchProjectType" + raise TypeError + + @classmethod + def get_queryset(cls, qs, info: Info): + return qs + + +@strawberry_django.type(Project) +class ArtProjectType(ProjectType): + artist: strawberry.auto + + +@strawberry_django.type(Project) +class ResearchProjectType(ProjectType): + supervisor: strawberry.auto + + +@strawberry_django.type(Company) +class CompanyType: + name: strawberry.auto + main_project: Optional[ProjectType] + projects: list[ProjectType] + + +@strawberry.type +class Query: + companies: list[CompanyType] = strawberry_django.field() + projects: list[ProjectType] = strawberry_django.field() + projects_paginated: list[ProjectType] = strawberry_django.field(pagination=True) + projects_offset_paginated: OffsetPaginated[ProjectType] = ( + strawberry_django.offset_paginated() + ) + projects_connection: ListConnectionWithTotalCount[ProjectType] = ( + strawberry_django.connection() + ) + + +schema = strawberry.Schema( + query=Query, + types=[ArtProjectType, ResearchProjectType], + extensions=[DjangoOptimizerExtension], +) diff --git a/tests/polymorphism_custom/test_optimizer.py b/tests/polymorphism_custom/test_optimizer.py new file mode 100644 index 00000000..0ff454db --- /dev/null +++ b/tests/polymorphism_custom/test_optimizer.py @@ -0,0 +1,313 @@ +import pytest +from django.db import DEFAULT_DB_ALIAS, connections +from django.test.utils import CaptureQueriesContext + +from tests.utils import assert_num_queries + +from .models import Company, Project +from .schema import schema + + +@pytest.mark.django_db(transaction=True) +def test_polymorphic_interface_query(): + ap = Project.objects.create(topic="Art", artist="Artist") + rp = Project.objects.create(topic="Research", supervisor="Supervisor") + + query = """\ + query { + projects { + __typename + topic + ... on ArtProjectType { + artist + } + ... on ResearchProjectType { + supervisor + } + } + } + """ + + with assert_num_queries(1): + result = schema.execute_sync(query) + assert not result.errors + assert result.data == { + "projects": [ + {"__typename": "ArtProjectType", "topic": ap.topic, "artist": ap.artist}, + { + "__typename": "ResearchProjectType", + "topic": rp.topic, + "supervisor": rp.supervisor, + }, + ] + } + + +@pytest.mark.django_db(transaction=True) +def test_polymorphic_query_optimization_working(): + ap = Project.objects.create(topic="Art", artist="Artist") + rp = Project.objects.create(topic="Research", supervisor="Supervisor") + + query = """\ + query { + projects { + __typename + topic + ... on ArtProjectType { + artist + } + ... on ResearchProjectType { + supervisor + } + } + } + """ + + with CaptureQueriesContext(connection=connections[DEFAULT_DB_ALIAS]) as ctx: + result = schema.execute_sync(query) + # validate that we're not selecting extra fields + assert any("artist" in q["sql"] for q in ctx.captured_queries) + assert not any("research_notes" in q["sql"] for q in ctx.captured_queries) + assert not result.errors + assert result.data == { + "projects": [ + {"__typename": "ArtProjectType", "topic": ap.topic, "artist": ap.artist}, + { + "__typename": "ResearchProjectType", + "topic": rp.topic, + "supervisor": rp.supervisor, + }, + ] + } + + +@pytest.mark.django_db(transaction=True) +def test_polymorphic_interface_paginated(): + ap = Project.objects.create(topic="Art", artist="Artist") + rp = Project.objects.create(topic="Research", supervisor="Supervisor") + + query = """\ + query { + projectsPaginated { + __typename + topic + ... on ArtProjectType { + artist + } + ... on ResearchProjectType { + supervisor + } + } + } + """ + + with assert_num_queries(1): + result = schema.execute_sync(query) + assert not result.errors + assert result.data == { + "projectsPaginated": [ + {"__typename": "ArtProjectType", "topic": ap.topic, "artist": ap.artist}, + { + "__typename": "ResearchProjectType", + "topic": rp.topic, + "supervisor": rp.supervisor, + }, + ] + } + + +@pytest.mark.django_db(transaction=True) +def test_polymorphic_interface_offset_paginated(): + ap = Project.objects.create(topic="Art", artist="Artist") + rp = Project.objects.create(topic="Research", supervisor="Supervisor") + + query = """\ + query { + projectsOffsetPaginated { + totalCount + results { + __typename + topic + ... on ArtProjectType { + artist + } + ... on ResearchProjectType { + supervisor + } + } + } + } + """ + + with assert_num_queries(2): + result = schema.execute_sync(query) + assert not result.errors + assert result.data == { + "projectsOffsetPaginated": { + "totalCount": 2, + "results": [ + { + "__typename": "ArtProjectType", + "topic": ap.topic, + "artist": ap.artist, + }, + { + "__typename": "ResearchProjectType", + "topic": rp.topic, + "supervisor": rp.supervisor, + }, + ], + } + } + + +@pytest.mark.django_db(transaction=True) +def test_polymorphic_interface_connection(): + ap = Project.objects.create(topic="Art", artist="Artist") + rp = Project.objects.create(topic="Research", supervisor="Supervisor") + + query = """\ + query { + projectsConnection { + totalCount + edges { + node { + __typename + topic + ... on ArtProjectType { + artist + } + ... on ResearchProjectType { + supervisor + } + } + } + } + } + """ + + with assert_num_queries(2): + result = schema.execute_sync(query) + assert not result.errors + assert result.data == { + "projectsConnection": { + "totalCount": 2, + "edges": [ + { + "node": { + "__typename": "ArtProjectType", + "topic": ap.topic, + "artist": ap.artist, + } + }, + { + "node": { + "__typename": "ResearchProjectType", + "topic": rp.topic, + "supervisor": rp.supervisor, + } + }, + ], + } + } + + +@pytest.mark.django_db(transaction=True) +def test_polymorphic_relation(): + ap = Project.objects.create(topic="Art", artist="Artist") + art_company = Company.objects.create(name="ArtCompany", main_project=ap) + + rp = Project.objects.create(topic="Research", supervisor="Supervisor") + research_company = Company.objects.create(name="ResearchCompany", main_project=rp) + + query = """\ + query { + companies { + name + mainProject { + __typename + topic + ... on ArtProjectType { + artist + } + ... on ResearchProjectType { + supervisor + } + } + } + } + """ + + with assert_num_queries(2): + result = schema.execute_sync(query) + assert not result.errors + assert result.data == { + "companies": [ + { + "name": art_company.name, + "mainProject": { + "__typename": "ArtProjectType", + "topic": ap.topic, + "artist": ap.artist, + }, + }, + { + "name": research_company.name, + "mainProject": { + "__typename": "ResearchProjectType", + "topic": rp.topic, + "supervisor": rp.supervisor, + }, + }, + ] + } + + +@pytest.mark.django_db(transaction=True) +def test_polymorphic_nested_list(): + company = Company.objects.create(name="Company") + ap = Project.objects.create(company=company, topic="Art", artist="Artist") + rp = Project.objects.create( + company=company, topic="Research", supervisor="Supervisor" + ) + + query = """\ + query { + companies { + name + projects { + __typename + topic + ... on ArtProjectType { + artist + } + ... on ResearchProjectType { + supervisor + } + } + } + } + """ + + with assert_num_queries(2): + result = schema.execute_sync(query) + assert not result.errors + assert result.data == { + "companies": [ + { + "name": "Company", + "projects": [ + { + "__typename": "ArtProjectType", + "topic": ap.topic, + "artist": ap.artist, + }, + { + "__typename": "ResearchProjectType", + "topic": rp.topic, + "supervisor": rp.supervisor, + }, + ], + } + ] + } diff --git a/tests/polymorphism_inheritancemanager/__init__.py b/tests/polymorphism_inheritancemanager/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/polymorphism_inheritancemanager/models.py b/tests/polymorphism_inheritancemanager/models.py new file mode 100644 index 00000000..a1b50b15 --- /dev/null +++ b/tests/polymorphism_inheritancemanager/models.py @@ -0,0 +1,64 @@ +from django.db import models +from model_utils.managers import InheritanceManager + + +class Company(models.Model): + name = models.CharField(max_length=100) + main_project = models.ForeignKey("Project", on_delete=models.CASCADE, null=True) + + class Meta: + ordering = ("name",) + + +class Project(models.Model): + company = models.ForeignKey( + Company, + null=True, + blank=True, + on_delete=models.CASCADE, + related_name="projects", + ) + topic = models.CharField(max_length=30) + + base_objects = InheritanceManager() + objects = InheritanceManager() + + class Meta: + base_manager_name = "base_objects" + + +class ArtProject(Project): + artist = models.CharField(max_length=30) + art_style = models.CharField(max_length=30) + + +class ResearchProject(Project): + supervisor = models.CharField(max_length=30) + research_notes = models.TextField() + + +class TechnicalProject(Project): + timeline = models.CharField(max_length=30) + + class Meta: # pyright: ignore [reportIncompatibleVariableOverride] + abstract = True + + +class SoftwareProject(TechnicalProject): + repository = models.CharField(max_length=255) + + +class EngineeringProject(TechnicalProject): + lead_engineer = models.CharField(max_length=255) + + +class AppProject(TechnicalProject): + repository = models.CharField(max_length=255) + + +class AndroidProject(AppProject): + android_version = models.CharField(max_length=15) + + +class IOSProject(AppProject): + ios_version = models.CharField(max_length=15) diff --git a/tests/polymorphism_inheritancemanager/schema.py b/tests/polymorphism_inheritancemanager/schema.py new file mode 100644 index 00000000..ff4bb024 --- /dev/null +++ b/tests/polymorphism_inheritancemanager/schema.py @@ -0,0 +1,98 @@ +from typing import Optional + +import strawberry + +import strawberry_django +from strawberry_django.optimizer import DjangoOptimizerExtension +from strawberry_django.pagination import OffsetPaginated + +from .models import ( + AndroidProject, + AppProject, + ArtProject, + Company, + EngineeringProject, + IOSProject, + Project, + ResearchProject, + SoftwareProject, + TechnicalProject, +) + + +@strawberry_django.interface(Project) +class ProjectType: + topic: strawberry.auto + + +@strawberry_django.type(ArtProject) +class ArtProjectType(ProjectType): + artist: strawberry.auto + + +@strawberry_django.type(ResearchProject) +class ResearchProjectType(ProjectType): + supervisor: strawberry.auto + + +@strawberry_django.interface(TechnicalProject) +class TechnicalProjectType(ProjectType): + timeline: strawberry.auto + + +@strawberry_django.type(SoftwareProject) +class SoftwareProjectType(TechnicalProjectType): + repository: strawberry.auto + + +@strawberry_django.type(EngineeringProject) +class EngineeringProjectType(TechnicalProjectType): + lead_engineer: strawberry.auto + + +@strawberry_django.interface(AppProject) +class AppProjectType(TechnicalProjectType): + repository: strawberry.auto + + +@strawberry_django.type(AndroidProject) +class AndroidProjectType(AppProjectType): + android_version: strawberry.auto + + +@strawberry_django.type(IOSProject) +class IOSProjectType(AppProjectType): + ios_version: strawberry.auto + + +@strawberry_django.type(Company) +class CompanyType: + name: strawberry.auto + projects: list[ProjectType] + main_project: Optional[ProjectType] + + +@strawberry.type +class Query: + companies: list[CompanyType] = strawberry_django.field() + projects: list[ProjectType] = strawberry_django.field() + projects_paginated: list[ProjectType] = strawberry_django.field(pagination=True) + projects_offset_paginated: OffsetPaginated[ProjectType] = ( + strawberry_django.offset_paginated() + ) + + +schema = strawberry.Schema( + query=Query, + types=[ + ArtProjectType, + ResearchProjectType, + TechnicalProjectType, + EngineeringProjectType, + SoftwareProjectType, + AppProjectType, + IOSProjectType, + AndroidProjectType, + ], + extensions=[DjangoOptimizerExtension], +) diff --git a/tests/polymorphism_inheritancemanager/test_optimizer.py b/tests/polymorphism_inheritancemanager/test_optimizer.py new file mode 100644 index 00000000..873c7ffc --- /dev/null +++ b/tests/polymorphism_inheritancemanager/test_optimizer.py @@ -0,0 +1,437 @@ +import pytest +from django.db import DEFAULT_DB_ALIAS, connections +from django.test.utils import CaptureQueriesContext + +from tests.utils import assert_num_queries + +from .models import ( + AndroidProject, + ArtProject, + Company, + EngineeringProject, + IOSProject, + ResearchProject, + SoftwareProject, +) +from .schema import schema + + +@pytest.mark.django_db(transaction=True) +def test_polymorphic_interface_query(): + ap = ArtProject.objects.create(topic="Art", artist="Artist") + rp = ResearchProject.objects.create(topic="Research", supervisor="Supervisor") + + query = """\ + query { + projects { + __typename + topic + ... on ArtProjectType { + artist + } + ... on ResearchProjectType { + supervisor + } + } + } + """ + + with assert_num_queries(1): + result = schema.execute_sync(query) + assert not result.errors + assert result.data == { + "projects": [ + {"__typename": "ArtProjectType", "topic": ap.topic, "artist": ap.artist}, + { + "__typename": "ResearchProjectType", + "topic": rp.topic, + "supervisor": rp.supervisor, + }, + ] + } + + +@pytest.mark.django_db(transaction=True) +def test_polymorphic_query_abstract_model(): + ap = ArtProject.objects.create(topic="Art", artist="Artist") + sp = SoftwareProject.objects.create( + topic="Software", repository="https://example.com", timeline="3 months" + ) + ep = EngineeringProject.objects.create( + topic="Engineering", lead_engineer="Elara Voss", timeline="6 years" + ) + + query = """\ + query { + projects { + __typename + topic + ... on ArtProjectType { + artist + } + ...on TechnicalProjectType { + timeline + } + ... on SoftwareProjectType { + repository + } + ...on EngineeringProjectType { + leadEngineer + } + } + } + """ + + with assert_num_queries(1): + result = schema.execute_sync(query) + assert not result.errors + assert result.data == { + "projects": [ + {"__typename": "ArtProjectType", "topic": ap.topic, "artist": ap.artist}, + { + "__typename": "SoftwareProjectType", + "topic": sp.topic, + "repository": sp.repository, + "timeline": sp.timeline, + }, + { + "__typename": "EngineeringProjectType", + "topic": ep.topic, + "leadEngineer": ep.lead_engineer, + "timeline": ep.timeline, + }, + ] + } + + +@pytest.mark.django_db(transaction=True) +def test_polymorphic_query_multiple_inheritance_levels(): + app1 = AndroidProject.objects.create( + topic="Software", + repository="https://example.com/android", + timeline="3 months", + android_version="14", + ) + app2 = IOSProject.objects.create( + topic="Software", + repository="https://example.com/ios", + timeline="5 months", + ios_version="16", + ) + ep = EngineeringProject.objects.create( + topic="Engineering", lead_engineer="Elara Voss", timeline="6 years" + ) + + query = """\ + query { + projects { + __typename + topic + ...on TechnicalProjectType { + timeline + } + ...on AppProjectType { + repository + } + ...on AndroidProjectType { + androidVersion + } + ...on IOSProjectType { + iosVersion + } + ...on EngineeringProjectType { + leadEngineer + } + } + } + """ + + with assert_num_queries(1): + result = schema.execute_sync(query) + assert not result.errors + assert result.data == { + "projects": [ + { + "__typename": "AndroidProjectType", + "topic": app1.topic, + "repository": app1.repository, + "timeline": app1.timeline, + "androidVersion": app1.android_version, + }, + { + "__typename": "IOSProjectType", + "topic": app2.topic, + "repository": app2.repository, + "timeline": app2.timeline, + "iosVersion": app2.ios_version, + }, + { + "__typename": "EngineeringProjectType", + "topic": ep.topic, + "leadEngineer": ep.lead_engineer, + "timeline": ep.timeline, + }, + ] + } + + +@pytest.mark.django_db(transaction=True) +def test_polymorphic_query_abstract_model_on_field(): + ep = EngineeringProject.objects.create( + topic="Engineering", lead_engineer="Elara Voss", timeline="6 years" + ) + company = Company.objects.create(name="Company", main_project=ep) + + query = """\ + query { + companies { + name + mainProject { + __typename + topic + ...on TechnicalProjectType { + timeline + } + ...on EngineeringProjectType { + leadEngineer + } + } + } + } + """ + + with assert_num_queries(2): + result = schema.execute_sync(query) + assert not result.errors + assert result.data == { + "companies": [ + { + "name": company.name, + "mainProject": { + "__typename": "EngineeringProjectType", + "topic": ep.topic, + "leadEngineer": ep.lead_engineer, + "timeline": ep.timeline, + }, + } + ] + } + + +@pytest.mark.django_db(transaction=True) +def test_polymorphic_query_optimization_working(): + ap = ArtProject.objects.create(topic="Art", artist="Artist") + rp = ResearchProject.objects.create(topic="Research", supervisor="Supervisor") + + query = """\ + query { + projects { + __typename + topic + ... on ArtProjectType { + artist + } + ... on ResearchProjectType { + supervisor + } + } + } + """ + + with CaptureQueriesContext(connection=connections[DEFAULT_DB_ALIAS]) as ctx: + result = schema.execute_sync(query) + # validate that we're not selecting extra fields + assert not any("research_notes" in q for q in ctx.captured_queries) + assert not any("art_style" in q for q in ctx.captured_queries) + assert not result.errors + assert result.data == { + "projects": [ + {"__typename": "ArtProjectType", "topic": ap.topic, "artist": ap.artist}, + { + "__typename": "ResearchProjectType", + "topic": rp.topic, + "supervisor": rp.supervisor, + }, + ] + } + + +@pytest.mark.django_db(transaction=True) +def test_polymorphic_paginated_query(): + ap = ArtProject.objects.create(topic="Art", artist="Artist") + rp = ResearchProject.objects.create(topic="Research", supervisor="Supervisor") + + query = """\ + query { + projectsPaginated { + __typename + topic + ... on ArtProjectType { + artist + } + ... on ResearchProjectType { + supervisor + } + } + } + """ + + with assert_num_queries(1): + result = schema.execute_sync(query) + assert not result.errors + assert result.data == { + "projectsPaginated": [ + {"__typename": "ArtProjectType", "topic": ap.topic, "artist": ap.artist}, + { + "__typename": "ResearchProjectType", + "topic": rp.topic, + "supervisor": rp.supervisor, + }, + ] + } + + +@pytest.mark.django_db(transaction=True) +def test_polymorphic_offset_paginated_query(): + ap = ArtProject.objects.create(topic="Art", artist="Artist") + rp = ResearchProject.objects.create(topic="Research", supervisor="Supervisor") + + query = """\ + query { + projectsOffsetPaginated { + totalCount + results { + __typename + topic + ... on ArtProjectType { + artist + } + ... on ResearchProjectType { + supervisor + } + } + } + } + """ + + with assert_num_queries(2): + result = schema.execute_sync(query) + assert not result.errors + assert result.data == { + "projectsOffsetPaginated": { + "totalCount": 2, + "results": [ + { + "__typename": "ArtProjectType", + "topic": ap.topic, + "artist": ap.artist, + }, + { + "__typename": "ResearchProjectType", + "topic": rp.topic, + "supervisor": rp.supervisor, + }, + ], + } + } + + +@pytest.mark.django_db(transaction=True) +def test_polymorphic_relation(): + ap = ArtProject.objects.create(topic="Art", artist="Artist") + art_company = Company.objects.create(name="ArtCompany", main_project=ap) + + rp = ResearchProject.objects.create(topic="Research", supervisor="Supervisor") + research_company = Company.objects.create(name="ResearchCompany", main_project=rp) + + query = """\ + query { + companies { + name + mainProject { + __typename + topic + ... on ArtProjectType { + artist + } + ... on ResearchProjectType { + supervisor + } + } + } + } + """ + + with assert_num_queries(2): + result = schema.execute_sync(query) + assert not result.errors + assert result.data == { + "companies": [ + { + "name": art_company.name, + "mainProject": { + "__typename": "ArtProjectType", + "topic": ap.topic, + "artist": ap.artist, + }, + }, + { + "name": research_company.name, + "mainProject": { + "__typename": "ResearchProjectType", + "topic": rp.topic, + "supervisor": rp.supervisor, + }, + }, + ] + } + + +@pytest.mark.django_db(transaction=True) +def test_polymorphic_nested_list(): + company = Company.objects.create(name="Company") + ap = ArtProject.objects.create(company=company, topic="Art", artist="Artist") + rp = ResearchProject.objects.create( + company=company, topic="Research", supervisor="Supervisor" + ) + + query = """\ + query { + companies { + name + projects { + __typename + topic + ... on ArtProjectType { + artist + } + ... on ResearchProjectType { + supervisor + } + } + } + } + """ + + with assert_num_queries(2): + result = schema.execute_sync(query) + assert not result.errors + assert result.data == { + "companies": [ + { + "name": "Company", + "projects": [ + { + "__typename": "ArtProjectType", + "topic": ap.topic, + "artist": ap.artist, + }, + { + "__typename": "ResearchProjectType", + "topic": rp.topic, + "supervisor": rp.supervisor, + }, + ], + } + ] + }