Skip to content

Commit d79e79c

Browse files
d4rkendariuszkuc
authored andcommitted
Additional SchemaGeneratorHook that allows modifying GraphQLTypes. (#69)
Additional SchemaGeneratorHook that enables rewiring based on the directives. Closes #60
1 parent 11f2e9a commit d79e79c

File tree

6 files changed

+213
-45
lines changed

6 files changed

+213
-45
lines changed

pom.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@
7171

7272
<properties>
7373
<reflections.version>0.9.11</reflections.version>
74-
<kotlin.version>1.2.71</kotlin.version>
74+
<kotlin.version>1.3.10</kotlin.version>
7575
<kotlin-ktlint.version>0.29.0</kotlin-ktlint.version>
7676
<kotlin-detekt.version>1.0.0.RC8</kotlin-detekt.version>
7777
<mockk.version>1.8.9.kotlin13</mockk.version>

src/main/kotlin/com/expedia/graphql/schema/extensions/annotationExtensions.kt

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ import graphql.schema.GraphQLArgument
1111
import graphql.schema.GraphQLDirective
1212
import graphql.schema.GraphQLInputType
1313
import kotlin.reflect.KAnnotatedElement
14-
import kotlin.reflect.KClass
14+
import kotlin.reflect.KParameter
1515
import kotlin.reflect.full.findAnnotation
1616
import com.expedia.graphql.annotations.GraphQLDirective as DirectiveAnnotation
1717

@@ -64,11 +64,10 @@ internal fun KAnnotatedElement.isGraphQLIgnored() = this.findAnnotation<GraphQLI
6464
internal fun KAnnotatedElement.isGraphQLID() = this.findAnnotation<GraphQLID>() != null
6565

6666
private fun Annotation.getDirectiveInfo(): DirectiveInfo? {
67-
val directiveAnnotation = this.annotationClass.annotations.find { it is DirectiveAnnotation } as? DirectiveAnnotation
68-
return when {
69-
directiveAnnotation != null -> DirectiveInfo(this.annotationClass.simpleName ?: "", directiveAnnotation)
70-
else -> null
71-
}
67+
return this.annotationClass.annotations
68+
.filterIsInstance(DirectiveAnnotation::class.java)
69+
.map { DirectiveInfo(this, it) }
70+
.firstOrNull()
7271
}
7372

7473
internal fun KAnnotatedElement.directives(hooks: SchemaGeneratorHooks) =
@@ -77,27 +76,32 @@ internal fun KAnnotatedElement.directives(hooks: SchemaGeneratorHooks) =
7776
.map { it.getGraphQLDirective(hooks) }
7877
.toList()
7978

79+
internal fun KParameter.directives(hooks: SchemaGeneratorHooks) =
80+
this.annotations.asSequence()
81+
.mapNotNull { it.getDirectiveInfo() }
82+
.map { it.getGraphQLDirective(hooks) }
83+
.toList()
84+
8085
@Throws(CouldNotGetNameOfAnnotationException::class)
8186
private fun DirectiveInfo.getGraphQLDirective(hooks: SchemaGeneratorHooks): GraphQLDirective {
82-
val kClass: KClass<out DirectiveAnnotation> = this.annotation.annotationClass
83-
val builder = GraphQLDirective.newDirective()
84-
val name: String = this.effectiveName ?: throw CouldNotGetNameOfAnnotationException(kClass)
87+
val directiveClass = this.directive.annotationClass
88+
val name: String = this.effectiveName ?: throw CouldNotGetNameOfAnnotationException(directiveClass)
8589

8690
@Suppress("Detekt.SpreadOperator")
91+
val builder = GraphQLDirective.newDirective()
92+
.name(name.normalizeDirectiveName())
93+
.validLocations(*this.directiveAnnotation.locations)
94+
.description(this.directiveAnnotation.description)
8795

88-
builder.name(name.normalizeDirectiveName())
89-
.validLocations(*this.annotation.locations)
90-
.description(this.annotation.description)
96+
directiveClass.getValidProperties(hooks).forEach { prop ->
97+
val propertyName = prop.name
98+
val value = prop.call(this.directive)
9199

92-
kClass.getValidFunctions(hooks).forEach { kFunction ->
93-
val propertyName = kFunction.name
94-
val value = kFunction.call(kClass)
95-
@Suppress("Detekt.UnsafeCast")
96-
val type = defaultGraphQLScalars(kFunction.returnType) as GraphQLInputType
100+
val type = defaultGraphQLScalars(prop.returnType) ?: hooks.willGenerateGraphQLType(prop.returnType)
97101
val argument = GraphQLArgument.newArgument()
98102
.name(propertyName)
99103
.value(value)
100-
.type(type)
104+
.type(type as? GraphQLInputType)
101105
.build()
102106
builder.argument(argument)
103107
}
@@ -107,10 +111,10 @@ private fun DirectiveInfo.getGraphQLDirective(hooks: SchemaGeneratorHooks): Grap
107111

108112
private fun String.normalizeDirectiveName() = CaseFormat.UPPER_CAMEL.to(CaseFormat.LOWER_CAMEL, this)
109113

110-
private data class DirectiveInfo(private val name: String, val annotation: DirectiveAnnotation) {
114+
private data class DirectiveInfo(val directive: Annotation, val directiveAnnotation: DirectiveAnnotation) {
111115
val effectiveName: String? = when {
112-
annotation.name.isNotEmpty() -> annotation.name
113-
name.isNotEmpty() -> name
116+
directiveAnnotation.name.isNotEmpty() -> directiveAnnotation.name
117+
directive.annotationClass.simpleName.isNullOrEmpty().not() -> directive.annotationClass.simpleName
114118
else -> null
115119
}
116120
}

src/main/kotlin/com/expedia/graphql/schema/generator/SchemaGenerator.kt

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,8 @@ internal class SchemaGenerator(
150150

151151
val monadType = config.hooks.willResolveMonad(fn.returnType)
152152
builder.type(graphQLTypeOf(monadType) as GraphQLOutputType)
153-
return builder.build()
153+
val graphQLType = builder.build()
154+
return config.hooks.onRewireGraphQLType(monadType, graphQLType) as GraphQLFieldDefinition
154155
}
155156

156157
private fun property(prop: KProperty<*>): GraphQLFieldDefinition {
@@ -162,20 +163,33 @@ internal class SchemaGenerator(
162163
.type(propertyType)
163164
.deprecate(prop.getDeprecationReason())
164165

165-
return if (config.dataFetcherFactory != null && prop.isLateinit) {
166+
prop.directives(config.hooks).forEach {
167+
fieldBuilder.withDirective(it)
168+
state.directives.add(it)
169+
}
170+
171+
val field = if (config.dataFetcherFactory != null && prop.isLateinit) {
166172
updatePropertyFieldBuilder(propertyType, fieldBuilder, config.dataFetcherFactory)
167173
} else {
168174
fieldBuilder
169175
}.build()
176+
177+
return config.hooks.onRewireGraphQLType(prop.returnType, field) as GraphQLFieldDefinition
170178
}
171179

172180
private fun argument(parameter: KParameter): GraphQLArgument {
173181
parameter.throwIfUnathorizedInterface()
174-
return GraphQLArgument.newArgument()
182+
val builder = GraphQLArgument.newArgument()
175183
.name(parameter.name)
176184
.description(parameter.graphQLDescription() ?: parameter.type.graphQLDescription())
177185
.type(graphQLTypeOf(parameter.type, true) as GraphQLInputType)
178-
.build()
186+
187+
parameter.directives(config.hooks).forEach {
188+
builder.withDirective(it)
189+
state.directives.add(it)
190+
}
191+
192+
return config.hooks.onRewireGraphQLType(parameter.type, builder.build()) as GraphQLArgument
179193
}
180194

181195
private fun graphQLTypeOf(type: KType, inputType: Boolean = false, annotatedAsID: Boolean = false): GraphQLType {
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
package com.expedia.graphql.schema.generator.directive
2+
3+
import graphql.Assert.assertNotNull
4+
import graphql.schema.GraphQLDirectiveContainer
5+
import graphql.schema.GraphQLFieldDefinition
6+
import graphql.schema.GraphQLObjectType
7+
import graphql.schema.GraphQLInterfaceType
8+
import graphql.schema.GraphQLUnionType
9+
import graphql.schema.GraphQLScalarType
10+
import graphql.schema.GraphQLEnumType
11+
import graphql.schema.GraphQLEnumValueDefinition
12+
import graphql.schema.GraphQLArgument
13+
import graphql.schema.GraphQLInputObjectField
14+
import graphql.schema.GraphQLInputObjectType
15+
import graphql.schema.GraphQLDirective
16+
import graphql.schema.GraphQLType
17+
import graphql.schema.idl.SchemaDirectiveWiring
18+
import graphql.schema.idl.SchemaDirectiveWiringEnvironment
19+
import graphql.schema.idl.SchemaDirectiveWiringEnvironmentImpl
20+
import graphql.schema.idl.WiringFactory
21+
22+
/**
23+
* Based on
24+
* https://github.com/graphql-java/graphql-java/blob/master/src/main/java/graphql/schema/idl/SchemaGeneratorDirectiveHelper.java
25+
*/
26+
class DirectiveWiringHelper(private val wiringFactory: WiringFactory, private val manualWiring: Map<String, SchemaDirectiveWiring> = mutableMapOf()) {
27+
28+
@Suppress("UNCHECKED_CAST", "Detekt.ComplexMethod")
29+
fun onWire(generatedType: GraphQLType): GraphQLType {
30+
if (generatedType !is GraphQLDirectiveContainer) return generatedType
31+
32+
return wireDirectives(generatedType, getDirectives(generatedType),
33+
{ outputElement, directive -> createWiringEnvironment(outputElement, directive) },
34+
{ wiring, environment ->
35+
when (environment.element) {
36+
is GraphQLObjectType -> wiring.onObject(environment as SchemaDirectiveWiringEnvironment<GraphQLObjectType>)
37+
is GraphQLFieldDefinition -> wiring.onField(environment as SchemaDirectiveWiringEnvironment<GraphQLFieldDefinition>)
38+
is GraphQLInterfaceType -> wiring.onInterface(environment as SchemaDirectiveWiringEnvironment<GraphQLInterfaceType>)
39+
is GraphQLUnionType -> wiring.onUnion(environment as SchemaDirectiveWiringEnvironment<GraphQLUnionType>)
40+
is GraphQLScalarType -> wiring.onScalar(environment as SchemaDirectiveWiringEnvironment<GraphQLScalarType>)
41+
is GraphQLEnumType -> wiring.onEnum(environment as SchemaDirectiveWiringEnvironment<GraphQLEnumType>)
42+
is GraphQLEnumValueDefinition -> wiring.onEnumValue(environment as SchemaDirectiveWiringEnvironment<GraphQLEnumValueDefinition>)
43+
is GraphQLArgument -> wiring.onArgument(environment as SchemaDirectiveWiringEnvironment<GraphQLArgument>)
44+
is GraphQLInputObjectType -> wiring.onInputObjectType(environment as SchemaDirectiveWiringEnvironment<GraphQLInputObjectType>)
45+
is GraphQLInputObjectField -> wiring.onInputObjectField(environment as SchemaDirectiveWiringEnvironment<GraphQLInputObjectField>)
46+
else -> generatedType
47+
}
48+
}
49+
)
50+
}
51+
52+
private fun getDirectives(generatedType: GraphQLDirectiveContainer): MutableList<GraphQLDirective> {
53+
// A function without directives may still be rewired if the arguments have directives
54+
val directives = generatedType.directives
55+
if (generatedType is GraphQLFieldDefinition) {
56+
generatedType.arguments.forEach { directives.addAll(it.directives) }
57+
}
58+
return directives
59+
}
60+
61+
private fun <T : GraphQLDirectiveContainer> createWiringEnvironment(element: T, directive: GraphQLDirective): SchemaDirectiveWiringEnvironment<T> =
62+
SchemaDirectiveWiringEnvironmentImpl(element, directive, null, null, null)
63+
64+
private fun <T : GraphQLDirectiveContainer> wireDirectives(
65+
element: T,
66+
directives: List<GraphQLDirective>,
67+
envBuilder: (T, GraphQLDirective) -> SchemaDirectiveWiringEnvironment<T>,
68+
invoker: (SchemaDirectiveWiring, SchemaDirectiveWiringEnvironment<T>) -> T
69+
): T {
70+
var outputObject = element
71+
for (directive in directives) {
72+
val env = envBuilder.invoke(outputObject, directive)
73+
val directiveWiring = discoverWiringProvider(directive.name, env)
74+
if (directiveWiring != null) {
75+
val newElement = invoker.invoke(directiveWiring, env)
76+
assertNotNull(newElement, "The SchemaDirectiveWiring MUST return a non null return value for element '" + element.name + "'")
77+
outputObject = newElement
78+
}
79+
}
80+
return outputObject
81+
}
82+
83+
private fun <T : GraphQLDirectiveContainer> discoverWiringProvider(directiveName: String, env: SchemaDirectiveWiringEnvironment<T>): SchemaDirectiveWiring? {
84+
return if (wiringFactory.providesSchemaDirectiveWiring(env)) {
85+
wiringFactory.getSchemaDirectiveWiring(env)
86+
} else {
87+
manualWiring[directiveName]
88+
}
89+
}
90+
}

src/main/kotlin/com/expedia/graphql/schema/hooks/SchemaGeneratorHooks.kt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,12 @@ interface SchemaGeneratorHooks {
5353
@Suppress("Detekt.FunctionOnlyReturningConstant")
5454
fun isValidFunction(function: KFunction<*>): Boolean = true
5555

56+
/**
57+
* Called after `willGenerateGraphQLType` and before `didGenerateGraphQLType`.
58+
* Enables you to change the wiring, e.g. directives to alter data fetchers.
59+
*/
60+
fun onRewireGraphQLType(type: KType, generatedType: GraphQLType): GraphQLType = generatedType
61+
5662
/**
5763
* Called after wrapping the type based on nullity but before adding the generated type to the schema
5864
*/

src/test/kotlin/com/expedia/graphql/schema/generator/DirectiveTests.kt

Lines changed: 73 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import com.expedia.graphql.TopLevelObjectDef
44
import com.expedia.graphql.annotations.GraphQLDirective
55
import com.expedia.graphql.schema.testSchemaConfig
66
import com.expedia.graphql.toSchema
7+
import graphql.Scalars
78
import graphql.introspection.Introspection
89
import graphql.schema.GraphQLInputObjectType
910
import graphql.schema.GraphQLNonNull
@@ -59,53 +60,106 @@ class DirectiveTests {
5960

6061
@Test
6162
@Suppress("Detekt.UnsafeCast")
62-
fun `SchemaGenerator creates directives`() {
63+
fun `Directive renaming`() {
6364
val schema = toSchema(listOf(TopLevelObjectDef(QueryObject())), config = testSchemaConfig)
6465

65-
val geographyType = schema.getType("Geography") as? GraphQLObjectType
66-
assertNotNull(geographyType?.getDirective("whatever"))
67-
assertNotNull(geographyType?.getFieldDefinition("somethingCool")?.getDirective("directiveOnFunction"))
68-
assertNotNull((schema.getType("Location") as? GraphQLObjectType)?.getDirective("renamedDirective"))
69-
assertNotNull(schema.getDirective("whatever"))
70-
assertNotNull(schema.getDirective("renamedDirective"))
71-
val directiveOnFunction = schema.getDirective("directiveOnFunction")
72-
assertNotNull(directiveOnFunction)
66+
val renamedDirective = assertNotNull(
67+
(schema.getType("Location") as? GraphQLObjectType)
68+
?.getDirective("rightNameDirective")
69+
)
70+
71+
assertEquals("arenaming", renamedDirective.arguments[0].value)
72+
assertEquals("arg", renamedDirective.arguments[0].name)
73+
assertEquals(Scalars.GraphQLString, renamedDirective.arguments[0].type)
74+
}
75+
76+
@Test
77+
@Suppress("Detekt.UnsafeCast")
78+
fun `Directives on classes`() {
79+
val schema = toSchema(listOf(TopLevelObjectDef(QueryObject())), config = testSchemaConfig)
80+
81+
val directive = assertNotNull(
82+
(schema.getType("Geography") as? GraphQLObjectType)
83+
?.getDirective("onClassDirective")
84+
)
85+
86+
assertEquals("aclass", directive.arguments[0].value)
87+
assertEquals("arg", directive.arguments[0].name)
88+
assertEquals(Scalars.GraphQLString, directive.arguments[0].type)
89+
}
90+
91+
@Test
92+
@Suppress("Detekt.UnsafeCast")
93+
fun `Directives on functions`() {
94+
val schema = toSchema(listOf(TopLevelObjectDef(QueryObject())), config = testSchemaConfig)
95+
96+
val directive = assertNotNull(
97+
(schema.getType("Geography") as? GraphQLObjectType)
98+
?.getFieldDefinition("somethingCool")
99+
?.getDirective("onFunctionDirective")
100+
)
101+
102+
assertEquals("afunction", directive.arguments[0].value)
103+
assertEquals("arg", directive.arguments[0].name)
104+
assertEquals(Scalars.GraphQLString, directive.arguments[0].type)
105+
106+
assertNotNull(directive)
73107
assertEquals(
74-
directiveOnFunction.validLocations()?.toSet(),
108+
directive.validLocations()?.toSet(),
75109
setOf(Introspection.DirectiveLocation.FIELD_DEFINITION, Introspection.DirectiveLocation.FIELD)
76110
)
77111
}
112+
113+
@Test
114+
@Suppress("Detekt.UnsafeCast")
115+
fun `Directives on arguments`() {
116+
val schema = toSchema(listOf(TopLevelObjectDef(QueryObject())), config = testSchemaConfig)
117+
118+
val directive = assertNotNull(
119+
schema.queryType
120+
.getFieldDefinition("query")
121+
.getArgument("value")
122+
.getDirective("onArgumentDirective")
123+
)
124+
125+
assertEquals("anargument", directive.arguments[0].value)
126+
assertEquals("arg", directive.arguments[0].name)
127+
assertEquals(Scalars.GraphQLString, directive.arguments[0].type)
128+
}
78129
}
79130

131+
@GraphQLDirective(name = "RightNameDirective")
132+
annotation class WrongNameDirective(val arg: String)
133+
80134
@GraphQLDirective
81-
annotation class Whatever
135+
annotation class OnClassDirective(val arg: String)
82136

83-
@GraphQLDirective(locations = [Introspection.DirectiveLocation.FIELD_DEFINITION, Introspection.DirectiveLocation.FIELD])
84-
annotation class DirectiveOnFunction
137+
@GraphQLDirective
138+
annotation class OnArgumentDirective(val arg: String)
85139

86-
@GraphQLDirective(name = "RenamedDirective")
87-
annotation class RenamedDirective(val x: Boolean)
140+
@GraphQLDirective(locations = [Introspection.DirectiveLocation.FIELD_DEFINITION, Introspection.DirectiveLocation.FIELD])
141+
annotation class OnFunctionDirective(val arg: String)
88142

89-
@Whatever
143+
@OnClassDirective(arg = "aclass")
90144
class Geography(
91145
val id: Int?,
92146
val type: GeoType,
93147
val locations: List<Location>
94148
) {
95149
@Suppress("Detekt.FunctionOnlyReturningConstant")
96-
@DirectiveOnFunction
150+
@OnFunctionDirective(arg = "afunction")
97151
fun somethingCool(): String = "Something cool"
98152
}
99153

100154
enum class GeoType {
101155
CITY, STATE
102156
}
103157

104-
@RenamedDirective(x = false)
158+
@WrongNameDirective(arg = "arenaming")
105159
data class Location(val lat: Double, val lon: Double)
106160

107161
class QueryObject {
108-
fun query(value: Int): Geography = Geography(value, GeoType.CITY, listOf())
162+
fun query(@OnArgumentDirective(arg = "anargument") value: Int): Geography = Geography(value, GeoType.CITY, listOf())
109163
}
110164

111165
class QueryWithDeprecatedFields {

0 commit comments

Comments
 (0)