Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 118 additions & 1 deletion src/main/java/graphql/schema/idl/SchemaTypeDirectivesChecker.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,17 @@
import graphql.schema.idl.errors.MissingTypeError;
import graphql.schema.idl.errors.NotAnInputTypeError;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

import static graphql.Assert.assertNotNull;
import static graphql.introspection.Introspection.DirectiveLocation.ARGUMENT_DEFINITION;
import static graphql.introspection.Introspection.DirectiveLocation.ENUM;
import static graphql.introspection.Introspection.DirectiveLocation.ENUM_VALUE;
Expand Down Expand Up @@ -182,6 +188,10 @@ private static boolean isNoNullArgWithoutDefaultValue(InputValueDefinition defin
}

private void commonCheck(Collection<DirectiveDefinition> directiveDefinitions, List<GraphQLError> errors) {
List<DirectiveDefinition> directiveDefinitionsList = new ArrayList<>(directiveDefinitions);
Map<String, DirectiveDefinition> directiveDefinitionsByName = getByName(directiveDefinitionsList, DirectiveDefinition::getName, mergeFirst());
Map<String, Map<String, InputValueDefinition>> directiveReferencesByName = directiveReferencesByName(directiveDefinitionsByName);

directiveDefinitions.forEach(directiveDefinition -> {
assertTypeName(directiveDefinition, errors);
directiveDefinition.getInputValueDefinitions().forEach(inputValueDefinition -> {
Expand All @@ -192,6 +202,113 @@ private void commonCheck(Collection<DirectiveDefinition> directiveDefinitions, L
}
});
});
checkIndirectDirectiveCycles(directiveDefinitionsByName, directiveReferencesByName, errors);
}

private static Map<String, Map<String, InputValueDefinition>> directiveReferencesByName(
Map<String, DirectiveDefinition> directiveDefinitionsByName) {
Map<String, Map<String, InputValueDefinition>> result = new LinkedHashMap<>();
directiveDefinitionsByName.forEach((name, directiveDefinition) -> result.put(name, directiveReferences(directiveDefinition)));
return result;
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you can use graphql.util.FpKit#getByName(java.util.List, java.util.function.Function<T,java.lang.String>, java.util.function.BinaryOperator)

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in c58ab2a. directiveReferencesByName(...) now consumes the map produced by FpKit#getByName(..., mergeFirst()) instead of rebuilding the directive-definition map itself, so the cycle checker uses the shared local utility as suggested.

}

private static Map<String, InputValueDefinition> directiveReferences(DirectiveDefinition directiveDefinition) {
Map<String, InputValueDefinition> result = new LinkedHashMap<>();
for (InputValueDefinition inputValueDefinition : directiveDefinition.getInputValueDefinitions()) {
recordDirectiveReferences(directiveDefinition, result, inputValueDefinition);
}
return result;
}

private static void recordDirectiveReferences(DirectiveDefinition directiveDefinition,
Map<String, InputValueDefinition> result,
InputValueDefinition inputValueDefinition) {
for (Directive directive : inputValueDefinition.getDirectives()) {
if (directive.getName().equals(directiveDefinition.getName())) {
continue;
}
result.putIfAbsent(directive.getName(), inputValueDefinition);
}
}

private static void checkIndirectDirectiveCycles(
Map<String, DirectiveDefinition> directiveDefinitionsByName,
Map<String, Map<String, InputValueDefinition>> directiveReferencesByName,
List<GraphQLError> errors) {
Set<String> checked = new LinkedHashSet<>();
Set<String> visiting = new LinkedHashSet<>();
List<String> path = new ArrayList<>();
for (String directiveName : directiveDefinitionsByName.keySet()) {
checkIndirectDirectiveCycles(directiveName, directiveDefinitionsByName, directiveReferencesByName, checked, visiting, path, errors);
}
}

private static void checkIndirectDirectiveCycles(String directiveName,
Map<String, DirectiveDefinition> directiveDefinitionsByName,
Map<String, Map<String, InputValueDefinition>> directiveReferencesByName,
Set<String> checked,
Set<String> visiting,
List<String> path,
List<GraphQLError> errors) {
if (checked.contains(directiveName)) {
return;
}

visiting.add(directiveName);
path.add(directiveName);
checkIndirectDirectiveCycleReferences(directiveName, directiveDefinitionsByName, directiveReferencesByName, checked, visiting, path, errors);
path.remove(path.size() - 1);
visiting.remove(directiveName);
checked.add(directiveName);
}

private static void checkIndirectDirectiveCycleReferences(String directiveName,
Map<String, DirectiveDefinition> directiveDefinitionsByName,
Map<String, Map<String, InputValueDefinition>> directiveReferencesByName,
Set<String> checked,
Set<String> visiting,
List<String> path,
List<GraphQLError> errors) {
Map<String, InputValueDefinition> references = directiveReferencesByName.getOrDefault(directiveName, Collections.emptyMap());
for (Map.Entry<String, InputValueDefinition> entry : references.entrySet()) {
checkIndirectDirectiveCycleReference(entry.getKey(), entry.getValue(), directiveDefinitionsByName, directiveReferencesByName, checked, visiting, path, errors);
}
}

private static void checkIndirectDirectiveCycleReference(String referencedDirectiveName,
InputValueDefinition inputValueDefinition,
Map<String, DirectiveDefinition> directiveDefinitionsByName,
Map<String, Map<String, InputValueDefinition>> directiveReferencesByName,
Set<String> checked,
Set<String> visiting,
List<String> path,
List<GraphQLError> errors) {
if (visiting.contains(referencedDirectiveName)) {
addIndirectDirectiveCycleError(referencedDirectiveName, inputValueDefinition, directiveDefinitionsByName, path, errors);
return;
}
if (!checked.contains(referencedDirectiveName)) {
checkIndirectDirectiveCycles(referencedDirectiveName, directiveDefinitionsByName, directiveReferencesByName, checked, visiting, path, errors);
}
}

private static void addIndirectDirectiveCycleError(String repeatedDirectiveName,
InputValueDefinition inputValueDefinition,
Map<String, DirectiveDefinition> directiveDefinitionsByName,
List<String> path,
List<GraphQLError> errors) {
List<String> cyclePath = directiveCyclePath(repeatedDirectiveName, path);
String cyclePathString = String.join(" -> ", cyclePath);

DirectiveDefinition directiveDefinition = assertNotNull(directiveDefinitionsByName.get(repeatedDirectiveName));
errors.add(new DirectiveIllegalReferenceError(directiveDefinition, inputValueDefinition, cyclePathString));
}

private static List<String> directiveCyclePath(String repeatedDirectiveName, List<String> path) {
int cycleStart = path.indexOf(repeatedDirectiveName);
List<String> cyclePath = new ArrayList<>(path.subList(cycleStart, path.size()));
cyclePath.add(repeatedDirectiveName);
return cyclePath;
}

private static void assertTypeName(NamedNode<?> node, List<GraphQLError> errors) {
Expand Down Expand Up @@ -224,4 +341,4 @@ private static TypeDefinition<?> findTypeDefFromRegistry(String typeName, TypeDe
}
return typeRegistry.scalars().get(typeName);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,11 @@ public DirectiveIllegalReferenceError(DirectiveDefinition directive, NamedNode l
directive.getName(), location.getName(), lineCol(location)
));
}
}

public DirectiveIllegalReferenceError(DirectiveDefinition directive, NamedNode location, String cyclePath) {
super(directive,
String.format("'%s' must not reference itself via directive cycle '%s' on '%s''%s'",
directive.getName(), cyclePath, location.getName(), lineCol(location)
));
}
}
41 changes: 41 additions & 0 deletions src/test/groovy/graphql/schema/idl/SchemaGeneratorTest.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import graphql.schema.GraphQLType
import graphql.schema.GraphQLTypeUtil
import graphql.schema.GraphQLUnionType
import graphql.schema.GraphqlTypeComparatorRegistry
import graphql.schema.idl.errors.DirectiveIllegalReferenceError
import graphql.schema.idl.errors.NotAnInputTypeError
import graphql.schema.idl.errors.NotAnOutputTypeError
import graphql.schema.idl.errors.SchemaProblem
Expand Down Expand Up @@ -2270,6 +2271,46 @@ class SchemaGeneratorTest extends Specification {
schema != null
}

def "#4201 indirect cyclical directive definitions are rejected without stack overflow - #name"() {
given:
def registry = new SchemaParser().parse(sdl)

when:
UnExecutableSchemaGenerator.makeUnExecutableSchema(registry)

then:
def e = thrown(SchemaProblem)
e.errors.size() == 1
e.errors.get(0) instanceof DirectiveIllegalReferenceError
e.errors.get(0).getMessage().contains(cycleMessage)

where:
name << ["two directives", "three directives"]
sdl << [
'''
directive @foo(x: Int @bar(y: 1)) on FIELD_DEFINITION | ARGUMENT_DEFINITION
directive @bar(y: Int @foo(x: 2)) on FIELD_DEFINITION | ARGUMENT_DEFINITION

type Query {
field: String @foo(x: 10) @bar(y: 20)
}
''',
'''
directive @dirA(x: Int @dirB(y: 1)) on FIELD_DEFINITION | ARGUMENT_DEFINITION
directive @dirB(y: Int @dirC(z: 2)) on FIELD_DEFINITION | ARGUMENT_DEFINITION
directive @dirC(z: Int @dirA(x: 3)) on FIELD_DEFINITION | ARGUMENT_DEFINITION

type Query {
field: String @dirA(x: 10) @dirB(y: 20) @dirC(z: 30)
}
'''
]
cycleMessage << [
"'foo' must not reference itself via directive cycle 'foo -> bar -> foo'",
"'dirA' must not reference itself via directive cycle 'dirA -> dirB -> dirC -> dirA'"
]
}

def "code registry default data fetcher is respected"() {
def sdl = '''
type Query {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,51 @@ class SchemaTypeDirectivesCheckerTest extends Specification {
errors.get(0).getMessage() == "'invalidExample' must not reference itself on 'arg''[@2:39]'"
}

def "directive must not indirectly reference itself"() {
given:
def spec = '''
directive @foo(arg: String @bar) on ARGUMENT_DEFINITION
directive @bar(arg: String @foo) on ARGUMENT_DEFINITION

type Query {
f1 : String
}
'''
def registry = parse(spec)
def errors = []

when:
new SchemaTypeDirectivesChecker(registry, RuntimeWiring.newRuntimeWiring().build()).checkTypeDirectives(errors)

then:
errors.size() == 1
errors.get(0) instanceof DirectiveIllegalReferenceError
errors.get(0).getMessage().contains("'foo' must not reference itself via directive cycle 'foo -> bar -> foo'")
}

def "directive must not indirectly reference itself through a longer cycle"() {
given:
def spec = '''
directive @dirA(x: Int @dirB(y: 1)) on ARGUMENT_DEFINITION
directive @dirB(y: Int @dirC(z: 2)) on ARGUMENT_DEFINITION
directive @dirC(z: Int @dirA(x: 3)) on ARGUMENT_DEFINITION

type Query {
f1 : String
}
'''
def registry = parse(spec)
def errors = []

when:
new SchemaTypeDirectivesChecker(registry, RuntimeWiring.newRuntimeWiring().build()).checkTypeDirectives(errors)

then:
errors.size() == 1
errors.get(0) instanceof DirectiveIllegalReferenceError
errors.get(0).getMessage().contains("'dirA' must not reference itself via directive cycle 'dirA -> dirB -> dirC -> dirA'")
}

def "directive must not begin with '__'"() {
given:
def spec = '''
Expand Down