Skip to content

Resource Identity: Adds attribute name mapping for Framework resource types #43496

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jul 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 14 additions & 17 deletions internal/provider/framework/identity_interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,14 @@ import (
"github.com/hashicorp/terraform-plugin-framework/diag"
"github.com/hashicorp/terraform-plugin-framework/path"
"github.com/hashicorp/terraform-plugin-framework/resource"
tfslices "github.com/hashicorp/terraform-provider-aws/internal/slices"
inttypes "github.com/hashicorp/terraform-provider-aws/internal/types"
"github.com/hashicorp/terraform-provider-aws/names"
)

var _ resourceCRUDInterceptor = identityInterceptor{}

type identityInterceptor struct {
attributes []string
attributes []inttypes.IdentityAttribute
}

func (r identityInterceptor) create(ctx context.Context, opts interceptorOptions[resource.CreateRequest, resource.CreateResponse]) diag.Diagnostics {
Expand All @@ -32,28 +31,28 @@ func (r identityInterceptor) create(ctx context.Context, opts interceptorOptions
break
}

for _, attrName := range r.attributes {
switch attrName {
for _, att := range r.attributes {
switch att.Name() {
case names.AttrAccountID:
diags.Append(identity.SetAttribute(ctx, path.Root(attrName), awsClient.AccountID(ctx))...)
diags.Append(identity.SetAttribute(ctx, path.Root(att.Name()), awsClient.AccountID(ctx))...)
if diags.HasError() {
return diags
}

case names.AttrRegion:
diags.Append(identity.SetAttribute(ctx, path.Root(attrName), awsClient.Region(ctx))...)
diags.Append(identity.SetAttribute(ctx, path.Root(att.Name()), awsClient.Region(ctx))...)
if diags.HasError() {
return diags
}

default:
var attrVal attr.Value
diags.Append(response.State.GetAttribute(ctx, path.Root(attrName), &attrVal)...)
diags.Append(response.State.GetAttribute(ctx, path.Root(att.ResourceAttributeName()), &attrVal)...)
if diags.HasError() {
return diags
}

diags.Append(identity.SetAttribute(ctx, path.Root(attrName), attrVal)...)
diags.Append(identity.SetAttribute(ctx, path.Root(att.Name()), attrVal)...)
if diags.HasError() {
return diags
}
Expand All @@ -78,28 +77,28 @@ func (r identityInterceptor) read(ctx context.Context, opts interceptorOptions[r
break
}

for _, attrName := range r.attributes {
switch attrName {
for _, att := range r.attributes {
switch att.Name() {
case names.AttrAccountID:
diags.Append(identity.SetAttribute(ctx, path.Root(attrName), awsClient.AccountID(ctx))...)
diags.Append(identity.SetAttribute(ctx, path.Root(att.Name()), awsClient.AccountID(ctx))...)
if diags.HasError() {
return diags
}

case names.AttrRegion:
diags.Append(identity.SetAttribute(ctx, path.Root(attrName), awsClient.Region(ctx))...)
diags.Append(identity.SetAttribute(ctx, path.Root(att.Name()), awsClient.Region(ctx))...)
if diags.HasError() {
return diags
}

default:
var attrVal attr.Value
diags.Append(response.State.GetAttribute(ctx, path.Root(attrName), &attrVal)...)
diags.Append(response.State.GetAttribute(ctx, path.Root(att.ResourceAttributeName()), &attrVal)...)
if diags.HasError() {
return diags
}

diags.Append(identity.SetAttribute(ctx, path.Root(attrName), attrVal)...)
diags.Append(identity.SetAttribute(ctx, path.Root(att.Name()), attrVal)...)
if diags.HasError() {
return diags
}
Expand All @@ -122,8 +121,6 @@ func (r identityInterceptor) delete(ctx context.Context, opts interceptorOptions

func newIdentityInterceptor(attributes []inttypes.IdentityAttribute) identityInterceptor {
return identityInterceptor{
attributes: tfslices.ApplyToAll(attributes, func(v inttypes.IdentityAttribute) string {
return v.Name()
}),
attributes: attributes,
}
}
109 changes: 91 additions & 18 deletions internal/provider/framework/identity_interceptor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"testing"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/hashicorp/terraform-plugin-framework/diag"
"github.com/hashicorp/terraform-plugin-framework/path"
"github.com/hashicorp/terraform-plugin-framework/resource"
"github.com/hashicorp/terraform-plugin-framework/resource/identityschema"
Expand All @@ -26,8 +27,6 @@ import (
func TestIdentityInterceptor(t *testing.T) {
t.Parallel()

ctx := t.Context()

accountID := "123456789012"
region := "us-west-2"
name := "a_name"
Expand All @@ -44,24 +43,80 @@ func TestIdentityInterceptor(t *testing.T) {
},
}

client := mockClient{
accountID: accountID,
region: region,
}

stateAttrs := map[string]string{
"name": name,
"region": region,
"type": "some_type",
}

identitySpec := regionalSingleParameterIdentitySpec("name")
identitySchema := identity.NewIdentitySchema(identitySpec)
testOperations := map[string]struct {
operation func(ctx context.Context, interceptor identityInterceptor, resourceSchema schema.Schema, stateAttrs map[string]string, identity *tfsdk.ResourceIdentity, client awsClient) (*tfsdk.ResourceIdentity, diag.Diagnostics)
}{
"create": {
operation: create,
},
"read": {
operation: read,
},
}

interceptor := newIdentityInterceptor(identitySpec.Attributes)
for tname, tc := range testOperations {
t.Run(tname, func(t *testing.T) {
t.Parallel()

client := mockClient{
accountID: accountID,
region: region,
}
operation := tc.operation

testCases := map[string]struct {
attrName string
identitySpec inttypes.Identity
}{
"same names": {
attrName: "name",
identitySpec: regionalSingleParameterIdentitySpec("name"),
},
"name mapped": {
attrName: "resource_name",
identitySpec: regionalSingleParameterIdentitySpecNameMapped("resource_name", "name"),
},
}

for tname, tc := range testCases {
t.Run(tname, func(t *testing.T) {
t.Parallel()
ctx := t.Context()

identitySchema := identity.NewIdentitySchema(tc.identitySpec)

interceptor := newIdentityInterceptor(tc.identitySpec.Attributes)

identity := emtpyIdentityFromSchema(ctx, &identitySchema)
identity := emtpyIdentityFromSchema(ctx, &identitySchema)

responseIdentity, diags := operation(ctx, interceptor, resourceSchema, stateAttrs, identity, client)
if len(diags) > 0 {
t.Fatalf("unexpected diags during interception: %s", diags)
}

if e, a := accountID, getIdentityAttributeValue(ctx, t, responseIdentity, path.Root("account_id")); e != a {
t.Errorf("expected Identity `account_id` to be %q, got %q", e, a)
}
if e, a := region, getIdentityAttributeValue(ctx, t, responseIdentity, path.Root("region")); e != a {
t.Errorf("expected Identity `region` to be %q, got %q", e, a)
}
if e, a := name, getIdentityAttributeValue(ctx, t, responseIdentity, path.Root(tc.attrName)); e != a {
t.Errorf("expected Identity `%s` to be %q, got %q", tc.attrName, e, a)
}
})
}
})
}
}

func create(ctx context.Context, interceptor identityInterceptor, resourceSchema schema.Schema, stateAttrs map[string]string, identity *tfsdk.ResourceIdentity, client awsClient) (*tfsdk.ResourceIdentity, diag.Diagnostics) {
request := resource.CreateRequest{
Config: configFromSchema(ctx, resourceSchema, stateAttrs),
Plan: planFromSchema(ctx, resourceSchema, stateAttrs),
Expand All @@ -79,19 +134,33 @@ func TestIdentityInterceptor(t *testing.T) {
}

diags := interceptor.create(ctx, opts)
if len(diags) > 0 {
t.Fatalf("unexpected diags during interception: %s", diags)
if diags.HasError() {
return nil, diags
}
return response.Identity, diags
}

if e, a := accountID, getIdentityAttributeValue(ctx, t, response.Identity, path.Root("account_id")); e != a {
t.Errorf("expected Identity `account_id` to be %q, got %q", e, a)
func read(ctx context.Context, interceptor identityInterceptor, resourceSchema schema.Schema, stateAttrs map[string]string, identity *tfsdk.ResourceIdentity, client awsClient) (*tfsdk.ResourceIdentity, diag.Diagnostics) {
request := resource.ReadRequest{
State: stateFromSchema(ctx, resourceSchema, stateAttrs),
Identity: identity,
}
if e, a := region, getIdentityAttributeValue(ctx, t, response.Identity, path.Root("region")); e != a {
t.Errorf("expected Identity `region` to be %q, got %q", e, a)
response := resource.ReadResponse{
State: stateFromSchema(ctx, resourceSchema, stateAttrs),
Identity: identity,
}
if e, a := name, getIdentityAttributeValue(ctx, t, response.Identity, path.Root("name")); e != a {
t.Errorf("expected Identity `name` to be %q, got %q", e, a)
opts := interceptorOptions[resource.ReadRequest, resource.ReadResponse]{
c: client,
request: &request,
response: &response,
when: After,
}

diags := interceptor.read(ctx, opts)
if diags.HasError() {
return nil, diags
}
return response.Identity, diags
}

func getIdentityAttributeValue(ctx context.Context, t *testing.T, identity *tfsdk.ResourceIdentity, path path.Path) string {
Expand All @@ -108,6 +177,10 @@ func regionalSingleParameterIdentitySpec(name string) inttypes.Identity {
return inttypes.RegionalSingleParameterIdentity(name)
}

func regionalSingleParameterIdentitySpecNameMapped(identityAttrName, resourceAttrName string) inttypes.Identity {
return inttypes.RegionalSingleParameterIdentityWithMappedName(identityAttrName, resourceAttrName)
}

func stateFromSchema(ctx context.Context, schema schema.Schema, values map[string]string) tfsdk.State {
val := make(map[string]tftypes.Value)
for name := range schema.Attributes {
Expand Down
20 changes: 12 additions & 8 deletions internal/provider/framework/importer/parameterized.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ import (
)

func SingleParameterized(ctx context.Context, client AWSClient, request resource.ImportStateRequest, identitySpec *inttypes.Identity, importSpec *inttypes.FrameworkImport, response *resource.ImportStateResponse) {
attrPath := path.Root(identitySpec.IdentityAttribute)
attr := identitySpec.Attributes[len(identitySpec.Attributes)-1]
identityPath := path.Root(attr.Name())
resourcePath := path.Root(attr.ResourceAttributeName())

parameterVal := request.ID

Expand All @@ -26,18 +28,18 @@ func SingleParameterized(ctx context.Context, client AWSClient, request resource
}

var parameterAttr types.String
response.Diagnostics.Append(identity.GetAttribute(ctx, attrPath, &parameterAttr)...)
response.Diagnostics.Append(identity.GetAttribute(ctx, identityPath, &parameterAttr)...)
if response.Diagnostics.HasError() {
return
}
parameterVal = parameterAttr.ValueString()
}

response.Diagnostics.Append(response.State.SetAttribute(ctx, attrPath, parameterVal)...)
response.Diagnostics.Append(response.State.SetAttribute(ctx, resourcePath, parameterVal)...)

if identity := response.Identity; identity != nil {
response.Diagnostics.Append(identity.SetAttribute(ctx, path.Root(names.AttrAccountID), client.AccountID(ctx))...)
response.Diagnostics.Append(identity.SetAttribute(ctx, attrPath, parameterVal)...)
response.Diagnostics.Append(identity.SetAttribute(ctx, identityPath, parameterVal)...)
}

if !identitySpec.IsGlobalResource {
Expand Down Expand Up @@ -80,18 +82,20 @@ func MultipleParameterized(ctx context.Context, client AWSClient, request resour
// Do nothing

default:
attrPath := path.Root(attr.Name())
identityPath := path.Root(attr.Name())
resourcePath := path.Root(attr.ResourceAttributeName())

var parameterAttr types.String
response.Diagnostics.Append(identity.GetAttribute(ctx, attrPath, &parameterAttr)...)
response.Diagnostics.Append(identity.GetAttribute(ctx, identityPath, &parameterAttr)...)
if response.Diagnostics.HasError() {
return
}
parameterVal := parameterAttr.ValueString()

response.Diagnostics.Append(response.State.SetAttribute(ctx, attrPath, parameterVal)...)
response.Diagnostics.Append(response.State.SetAttribute(ctx, resourcePath, parameterVal)...)

if identity := response.Identity; identity != nil {
response.Diagnostics.Append(identity.SetAttribute(ctx, attrPath, parameterVal)...)
response.Diagnostics.Append(identity.SetAttribute(ctx, identityPath, parameterVal)...)
}
}
}
Expand Down
Loading
Loading