Skip to content

Implement Display for nested shapes referenced in error messages #4081

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderGenerat
import software.amazon.smithy.rust.codegen.core.smithy.generators.StructureGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator
import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolGeneratorFactory
import software.amazon.smithy.rust.codegen.core.smithy.transformers.AddSyntheticTraitForImplDisplay
import software.amazon.smithy.rust.codegen.core.smithy.transformers.EventStreamNormalizer
import software.amazon.smithy.rust.codegen.core.smithy.transformers.OperationNormalizer
import software.amazon.smithy.rust.codegen.core.smithy.transformers.RecursiveShapeBoxer
Expand Down Expand Up @@ -146,6 +147,9 @@ class ClientCodegenVisitor(
.let(EventStreamNormalizer::transform)
// Mark operations incompatible with stalled stream protection as such
.let(DisableStalledStreamProtection::transformModel)
// Add synthetic trait to shapes referenced by error types to ensure they implement `Display`.
// This ensures error formatting works correctly for nested structures.
.let(AddSyntheticTraitForImplDisplay::transform)

/**
* Execute code generation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import software.amazon.smithy.rust.codegen.client.smithy.generators.OperationCus
import software.amazon.smithy.rust.codegen.client.smithy.generators.OperationSection
import software.amazon.smithy.rust.codegen.client.smithy.generators.http.ResponseBindingGenerator
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute
import software.amazon.smithy.rust.codegen.core.rustlang.RustType
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.assignment
Expand All @@ -33,6 +34,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpLocation
import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol
import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolFunctions
import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.StructuredDataParserGenerator
import software.amazon.smithy.rust.codegen.core.smithy.rustType
import software.amazon.smithy.rust.codegen.core.smithy.transformers.operationErrors
import software.amazon.smithy.rust.codegen.core.util.UNREACHABLE
import software.amazon.smithy.rust.codegen.core.util.dq
Expand Down Expand Up @@ -163,10 +165,11 @@ class ProtocolParserGenerator(
}
}
val errorMessageMember = errorShape.errorMessageMember()
// If the message member is optional and wasn't set, we set a generic error message.
// If the message member is optional, is of `String` Rust type and wasn't set, we set a generic error message.
if (errorMessageMember != null) {
val symbol = symbolProvider.toSymbol(errorMessageMember)
if (symbol.isOptional()) {
val currentRustType = symbol.rustType()
if (symbol.isOptional() && currentRustType == RustType.String) {
rust(
"""
if tmp.message.is_none() {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package software.amazon.smithy.rust.codegen.client.smithy.generators

import org.junit.jupiter.api.Test
import software.amazon.smithy.rust.codegen.client.testutil.clientIntegrationTest
import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel
import java.io.File

class ClientErrorReachableShapesDisplayTest {
@Test
fun correctMissingFields() {
var sampleModel = File("../codegen-core/common-test-models/nested-error.smithy").readText().asSmithyModel()
clientIntegrationTest(sampleModel) { _, _ ->
// It should compile.
}
}
}
89 changes: 89 additions & 0 deletions codegen-core/common-test-models/nested-error.smithy
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
$version: "2"

namespace sample

use smithy.framework#ValidationException
use aws.protocols#restJson1

@restJson1
service SampleService {
operations: [SampleOperation]
}

@http(uri: "/anOperation", method: "POST")
operation SampleOperation {
output:= {}
input:= {}
errors: [
SimpleError,
ErrorInInput,
ErrorWithDeepCompositeShape,
ComposedSensitiveError,
]
}

@error("client")
structure SimpleError {
message: String
}

@error("client")
structure ErrorInInput {
message: ErrorMessage
}

structure ErrorMessage {
@required
statusCode: Integer
@required
errorMessage: String
@required
isRetryable: Boolean
requestId: String
timeStamp: Timestamp
ratio: Float
precision: Double
dataSize: Long
byteCount: Short
flags: Byte
documentData: Document
blobData: Blob
tags: Map
errorCodes: List
}

map Map {
key: String,
value: String
}

list List {
member: Integer
}

structure WrappedErrorMessage {
someValue: Integer
contained: ErrorMessage
}

@error("client")
structure ErrorWithDeepCompositeShape {
message: WrappedErrorMessage
}

@sensitive
structure SensitiveMessage {
nothing: String
should: String
bePrinted: String
}

@error("server")
structure ComposedSensitiveError {
message: SensitiveMessage
}

@error("server")
structure ErrorWithNestedError {
message: ErrorWithDeepCompositeShape
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,12 @@ package software.amazon.smithy.rust.codegen.core.smithy.generators

import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.BlobShape
import software.amazon.smithy.model.shapes.DocumentShape
import software.amazon.smithy.model.shapes.ListShape
import software.amazon.smithy.model.shapes.MapShape
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.traits.ErrorTrait
import software.amazon.smithy.model.traits.SensitiveTrait
Expand All @@ -24,15 +29,21 @@ import software.amazon.smithy.rust.codegen.core.rustlang.isDeref
import software.amazon.smithy.rust.codegen.core.rustlang.render
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.stripOuter
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.customize.NamedCustomization
import software.amazon.smithy.rust.codegen.core.smithy.customize.Section
import software.amazon.smithy.rust.codegen.core.smithy.customize.writeCustomizations
import software.amazon.smithy.rust.codegen.core.smithy.expectRustMetadata
import software.amazon.smithy.rust.codegen.core.smithy.isOptional
import software.amazon.smithy.rust.codegen.core.smithy.protocols.serialize.ValueExpression
import software.amazon.smithy.rust.codegen.core.smithy.renamedFrom
import software.amazon.smithy.rust.codegen.core.smithy.rustType
import software.amazon.smithy.rust.codegen.core.smithy.shape
import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticImplDisplayTrait
import software.amazon.smithy.rust.codegen.core.util.REDACTION
import software.amazon.smithy.rust.codegen.core.util.dq
import software.amazon.smithy.rust.codegen.core.util.getTrait
Expand Down Expand Up @@ -128,6 +139,72 @@ open class StructureGenerator(
}
}

private fun renderImplDisplayIfSyntheticImplDisplayTraitApplied() {
if (shape.getTrait<SyntheticImplDisplayTrait>() == null) {
return
}

val lifetime = shape.lifetimeDeclaration(symbolProvider)
writer.rustBlock(
"impl ${shape.lifetimeDeclaration(symbolProvider)} #T for $name $lifetime",
RuntimeType.Display,
) {
writer.rustBlock("fn fmt(&self, f: &mut #1T::Formatter<'_>) -> #1T::Result", RuntimeType.stdFmt) {
write("""::std::write!(f, "$name {{")?;""")

var separator = ""
for (index in members.indices) {
val member = members[index]
val memberName = symbolProvider.toMemberName(member)
val memberSymbol = symbolProvider.toSymbol(member)

val shouldRedact = shape.shouldRedact(model) || member.shouldRedact(model)
// If the shape is redacted then each member shape will be redacted.
if (shouldRedact) {
write("""::std::write!(f, "$separator$memberName={}", $REDACTION)?;""")
} else {
val variable = ValueExpression.Reference("&self.$memberName")

val target = model.expectShape(member.target)
when (target) {
is DocumentShape, is BlobShape, is MapShape, is ListShape -> {
// Just print the member field name but not the value.
if (memberSymbol.isOptional()) {
rustBlockTemplate("if let #{Some}(_) = ${variable.asRef()}", *preludeScope) {
write("""::std::write!(f, "$separator$memberName=Some()")?;""")
}
rustBlock("else") {
write("""::std::write!(f, "$separator$memberName=None")?;""")
}
} else {
write("""::std::write!(f, "$separator$memberName=")?;""")
}
}
else -> {
if (memberSymbol.isOptional()) {
rustBlockTemplate("if let #{Some}(inner) = ${variable.asRef()}", *preludeScope) {
write("""::std::write!(f, "$separator$memberName=Some({})", inner)?;""")
}
rustBlock("else") {
write("""::std::write!(f, "$separator$memberName=None")?;""")
}
} else {
write("""::std::write!(f, "$separator$memberName={}", ${variable.asRef()})?;""")
}
}
}
}

if (separator.isEmpty()) {
separator = ", "
}
}

write("""::std::write!(f, "}}")""")
}
}
}

private fun renderStructureImpl() {
if (accessorMembers.isEmpty()) {
return
Expand Down Expand Up @@ -209,6 +286,7 @@ open class StructureGenerator(
if (!containerMeta.hasDebugDerive()) {
renderDebugImpl()
}
renderImplDisplayIfSyntheticImplDisplayTraitApplied()

writer.writeCustomizations(customizations, StructureSection.AdditionalTraitImpls(shape, name))
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package software.amazon.smithy.rust.codegen.core.smithy.traits

import software.amazon.smithy.model.node.Node
import software.amazon.smithy.model.shapes.ShapeId
import software.amazon.smithy.model.traits.AnnotationTrait

class SyntheticImplDisplayTrait : AnnotationTrait(ID, Node.objectNode()) {
companion object {
val ID: ShapeId = ShapeId.from("software.amazon.smithy.rust.codegen.core.smithy.traits#syntheticImplDisplayTrait")
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package software.amazon.smithy.rust.codegen.core.smithy.transformers

import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.AbstractShapeBuilder
import software.amazon.smithy.model.shapes.EnumShape
import software.amazon.smithy.model.shapes.ListShape
import software.amazon.smithy.model.shapes.MapShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.model.traits.ErrorTrait
import software.amazon.smithy.model.transform.ModelTransformer
import software.amazon.smithy.rust.codegen.core.smithy.DirectedWalker
import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticImplDisplayTrait
import software.amazon.smithy.rust.codegen.core.util.UNREACHABLE
import software.amazon.smithy.rust.codegen.core.util.getTrait
import software.amazon.smithy.utils.ToSmithyBuilder

/**
* Adds a synthetic trait to shapes that are reachable from error shapes to ensure they
* implement the `Display` trait in generated code.
*
* When a shape is annotated with `@error`, it needs to implement Rust's `Display` trait.
* If the error shape contains references to other structures, those structures also
* need to implement `Display` for proper error formatting.
*/
object AddSyntheticTraitForImplDisplay {
/**
* Transforms the model by adding [SyntheticImplDisplayTrait] to all shapes that are:
* 1. Reachable from an error shape
* 2. Not already marked with `@error`
* 3. Of a type that can implement `Display` (structure, list, union, or map)
*
* @param model The input model to transform
* @return The transformed model with synthetic traits added
*/
fun transform(model: Model): Model {
val walker = DirectedWalker(model)

// Find all error shapes from operations.
val errorShapes =
model.operationShapes
.flatMap { it.errors }
.mapNotNull { model.expectShape(it).asStructureShape().orElse(null) }

// Get shapes reachable from error shapes that need Display impl.
val shapesNeedingDisplay =
errorShapes
.flatMap { walker.walkShapes(it) }
.filter {
(it is StructureShape || it is ListShape || it is UnionShape || it is MapShape || it is EnumShape) &&
it.getTrait<ErrorTrait>() == null
}

// Add synthetic trait to identified shapes.
val transformedShapes =
shapesNeedingDisplay.mapNotNull { shape ->
if (shape !is ToSmithyBuilder<*>) {
UNREACHABLE("Shapes reachable from error shapes should be buildable")
return@mapNotNull null
}

val builder = shape.toBuilder()
if (builder is AbstractShapeBuilder<*, *>) {
builder.addTrait(SyntheticImplDisplayTrait()).build()
} else {
UNREACHABLE("`impl Display` cannot be generated for ${shape.id}")
null
}
}

return ModelTransformer.create().replaceShapes(model, transformedShapes)
}
}
Loading