Skip to content

Commit ba5c349

Browse files
committed
Refactor get_possible_concrete_types into utils/inspect.py
1 parent 3ab8dfd commit ba5c349

File tree

2 files changed

+74
-58
lines changed

2 files changed

+74
-58
lines changed

strawberry_django/optimizer.py

Lines changed: 11 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
import copy
66
import dataclasses
77
import itertools
8-
from collections import Counter, defaultdict
9-
from collections.abc import Callable, Iterable
8+
from collections import Counter
9+
from collections.abc import Callable
1010
from typing import (
1111
TYPE_CHECKING,
1212
Any,
@@ -42,11 +42,11 @@
4242
from strawberry.schema.schema import Schema
4343
from strawberry.schema.schema_converter import get_arguments
4444
from strawberry.types import get_object_definition
45-
from strawberry.types.base import StrawberryContainer, StrawberryType
45+
from strawberry.types.base import StrawberryContainer
4646
from strawberry.types.info import Info
4747
from strawberry.types.lazy_type import LazyType
4848
from strawberry.types.object_type import StrawberryObjectDefinition
49-
from typing_extensions import TypeGuard, assert_never, assert_type
49+
from typing_extensions import assert_never, assert_type
5050

5151
from strawberry_django.fields.types import resolve_model_field_name
5252
from strawberry_django.pagination import OffsetPaginated, apply_window_pagination
@@ -59,7 +59,9 @@
5959
PrefetchInspector,
6060
get_model_field,
6161
get_model_fields,
62+
get_possible_concrete_types,
6263
get_possible_type_definitions,
64+
is_polymorphic_model,
6365
)
6466
from .utils.typing import (
6567
AnnotateCallable,
@@ -77,7 +79,6 @@
7779
from collections.abc import Generator
7880

7981
from django.contrib.contenttypes.fields import GenericRelation
80-
from polymorphic.models import PolymorphicModel
8182
from strawberry.types.execution import ExecutionContext
8283
from strawberry.types.field import StrawberryField
8384
from strawberry.utils.await_maybe import AwaitableOrValue
@@ -96,20 +97,6 @@
9697

9798
_sentinel = object()
9899
_annotate_placeholder = "__annotated_placeholder__"
99-
_interfaces: defaultdict[
100-
Schema,
101-
dict[StrawberryObjectDefinition, list[StrawberryObjectDefinition]],
102-
] = defaultdict(
103-
dict,
104-
)
105-
106-
107-
def _is_polymorphic_model(v: type) -> TypeGuard[type[PolymorphicModel]]:
108-
try:
109-
from polymorphic.models import PolymorphicModel
110-
except ImportError:
111-
return False
112-
return issubclass(v, PolymorphicModel)
113100

114101

115102
@dataclasses.dataclass
@@ -1038,7 +1025,7 @@ def _get_model_hints(
10381025
# These must be prefixed with app_label__ModelName___ (note three underscores)
10391026
# This is a special syntax for django-polymorphic:
10401027
# https://django-polymorphic.readthedocs.io/en/stable/advanced.html#polymorphic-filtering-for-fields-in-inherited-classes
1041-
if _is_polymorphic_model(model) and issubclass(dj_definition.model, model):
1028+
if is_polymorphic_model(model) and issubclass(dj_definition.model, model):
10421029
return _get_model_hints(
10431030
dj_definition.model,
10441031
schema,
@@ -1060,7 +1047,7 @@ def _get_model_hints(
10601047
store.only.append(prefix + pk.attname)
10611048

10621049
# If this is a polymorphic Model, make sure to select its content type
1063-
if _is_polymorphic_model(model):
1050+
if is_polymorphic_model(model):
10641051
store.only.extend(prefix + f for f in model.polymorphic_internal_model_fields)
10651052

10661053
selections = [
@@ -1214,7 +1201,7 @@ def _get_model_hints_from_connection(
12141201
if node.name.value != "node":
12151202
continue
12161203

1217-
for concrete_n_type in _get_possible_concrete_types(
1204+
for concrete_n_type in get_possible_concrete_types(
12181205
model, schema, n_definition
12191206
):
12201207
n_gql_definition = _get_gql_definition(schema, concrete_n_type)
@@ -1271,9 +1258,7 @@ def _get_model_hints_from_paginated(
12711258
if selection.name.value != "results":
12721259
continue
12731260

1274-
for concrete_n_type in _get_possible_concrete_types(
1275-
model, schema, n_definition
1276-
):
1261+
for concrete_n_type in get_possible_concrete_types(model, schema, n_definition):
12771262
n_gql_definition = _get_gql_definition(
12781263
schema,
12791264
concrete_n_type,
@@ -1305,38 +1290,6 @@ def _get_model_hints_from_paginated(
13051290
return store
13061291

13071292

1308-
def _get_possible_concrete_types(
1309-
model: type[models.Model],
1310-
schema: Schema,
1311-
strawberry_type: StrawberryObjectDefinition | StrawberryType,
1312-
) -> Iterable[StrawberryObjectDefinition]:
1313-
for object_definition in get_possible_type_definitions(strawberry_type):
1314-
if object_definition.is_interface:
1315-
interface_definitions = _interfaces[schema].get(object_definition)
1316-
if interface_definitions is None:
1317-
interface_definitions = []
1318-
for t in schema.schema_converter.type_map.values():
1319-
t_definition = t.definition
1320-
if isinstance(
1321-
t_definition, StrawberryObjectDefinition
1322-
) and issubclass(t_definition.origin, object_definition.origin):
1323-
interface_definitions.append(t_definition)
1324-
_interfaces[schema][object_definition] = interface_definitions
1325-
1326-
for interface_definition in interface_definitions:
1327-
dj_definition = get_django_definition(interface_definition.origin)
1328-
if dj_definition and (
1329-
issubclass(model, dj_definition.model)
1330-
or (
1331-
_is_polymorphic_model(model)
1332-
and issubclass(dj_definition.model, model)
1333-
)
1334-
):
1335-
yield interface_definition
1336-
else:
1337-
yield object_definition
1338-
1339-
13401293
def optimize(
13411294
qs: QuerySet[_M] | BaseManager[_M],
13421295
info: GraphQLResolveInfo | Info,
@@ -1400,7 +1353,7 @@ def optimize(
14001353
if strawberry_type is None:
14011354
return qs
14021355

1403-
for inner_object_definition in _get_possible_concrete_types(
1356+
for inner_object_definition in get_possible_concrete_types(
14041357
qs.model, schema, strawberry_type
14051358
):
14061359
parent_type = _get_gql_definition(schema, inner_object_definition)

strawberry_django/utils/inspect.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,17 @@
33
import dataclasses
44
import functools
55
import itertools
6+
from collections import defaultdict
7+
from collections.abc import Iterable
68
from typing import (
79
TYPE_CHECKING,
10+
TypeGuard,
811
cast,
912
)
1013

1114
from django.db.models.query import Prefetch, QuerySet
1215
from django.db.models.sql.where import WhereNode
16+
from strawberry import Schema
1317
from strawberry.types import has_object_definition
1418
from strawberry.types.base import (
1519
StrawberryContainer,
@@ -25,6 +29,7 @@
2529
from strawberry_django.fields.types import resolve_model_field_name
2630

2731
from .pyutils import DictTree, dicttree_insersection_differs, dicttree_merge
32+
from .typing import get_django_definition
2833

2934
if TYPE_CHECKING:
3035
from collections.abc import Generator, Iterable
@@ -34,6 +39,7 @@
3439
from django.db.models.fields import Field
3540
from django.db.models.fields.reverse_related import ForeignObjectRel
3641
from django.db.models.sql.query import Query
42+
from polymorphic.models import PolymorphicModel
3743

3844

3945
@functools.lru_cache
@@ -143,6 +149,63 @@ def get_possible_type_definitions(
143149
yield t.__strawberry_definition__
144150

145151

152+
def is_polymorphic_model(v: type) -> TypeGuard[type[PolymorphicModel]]:
153+
try:
154+
from polymorphic.models import PolymorphicModel
155+
except ImportError:
156+
return False
157+
return issubclass(v, PolymorphicModel)
158+
159+
160+
def _can_optimize_subtypes(model: type[models.Model]) -> bool:
161+
return is_polymorphic_model(model)
162+
163+
164+
_interfaces: defaultdict[
165+
Schema,
166+
dict[StrawberryObjectDefinition, list[StrawberryObjectDefinition]],
167+
] = defaultdict(dict)
168+
169+
170+
def get_possible_concrete_types(
171+
model: type[models.Model],
172+
schema: Schema,
173+
strawberry_type: StrawberryObjectDefinition | StrawberryType,
174+
) -> Iterable[StrawberryObjectDefinition]:
175+
"""Return the object definitions the optimizer should look at when optimizing a model.
176+
177+
Returns any object definitions attached to either the model or one of its supertypes.
178+
179+
If the model is one that supports polymorphism, by returning subtypes from its queryset, subtypes are also
180+
looked at. Currently, this is only supported for django-polymorphic.
181+
"""
182+
for object_definition in get_possible_type_definitions(strawberry_type):
183+
if object_definition.is_interface:
184+
interface_definitions = _interfaces[schema].get(object_definition)
185+
if interface_definitions is None:
186+
interface_definitions = []
187+
for t in schema.schema_converter.type_map.values():
188+
t_definition = t.definition
189+
if isinstance(
190+
t_definition, StrawberryObjectDefinition
191+
) and issubclass(t_definition.origin, object_definition.origin):
192+
interface_definitions.append(t_definition)
193+
_interfaces[schema][object_definition] = interface_definitions
194+
195+
for interface_definition in interface_definitions:
196+
dj_definition = get_django_definition(interface_definition.origin)
197+
if dj_definition and (
198+
issubclass(model, dj_definition.model)
199+
or (
200+
_can_optimize_subtypes(model)
201+
and issubclass(dj_definition.model, model)
202+
)
203+
):
204+
yield interface_definition
205+
else:
206+
yield object_definition
207+
208+
146209
@dataclasses.dataclass(eq=True)
147210
class PrefetchInspector:
148211
"""Prefetch hints."""

0 commit comments

Comments
 (0)