Skip to content

Commit a8ecbaa

Browse files
committed
add credential cache
1 parent 976f893 commit a8ecbaa

File tree

4 files changed

+353
-0
lines changed

4 files changed

+353
-0
lines changed

azure/credential_cache.go

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
/*
2+
Copyright 2024 The Kubernetes Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package azure
18+
19+
import (
20+
"sync"
21+
22+
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
23+
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
24+
)
25+
26+
type credentialCache struct {
27+
mut *sync.Mutex
28+
cache map[credentialCacheKey]azcore.TokenCredential
29+
credFactory credentialFactory
30+
}
31+
32+
type credentialFactory interface {
33+
newClientSecretCredential(tenantID string, clientID string, clientSecret string, opts *azidentity.ClientSecretCredentialOptions) (azcore.TokenCredential, error)
34+
newClientCertificateCredential(tenantID string, clientID string, clientCertificate []byte, clientCertificatePassword []byte, opts *azidentity.ClientCertificateCredentialOptions) (azcore.TokenCredential, error)
35+
newManagedIdentityCredential(opts *azidentity.ManagedIdentityCredentialOptions) (azcore.TokenCredential, error)
36+
newWorkloadIdentityCredential(opts *azidentity.WorkloadIdentityCredentialOptions) (azcore.TokenCredential, error)
37+
}
38+
39+
// CredentialType represents the auth mechanism in use.
40+
type CredentialType int
41+
42+
const (
43+
// CredentialTypeClientSecret is for Service Principals with Client Secrets.
44+
CredentialTypeClientSecret CredentialType = iota
45+
// CredentialTypeClientCert is for Service Principals with Client certificates.
46+
CredentialTypeClientCert
47+
// CredentialTypeManagedIdentity is for Managed Identities.
48+
CredentialTypeManagedIdentity
49+
// CredentialTypeWorkloadIdentity is for Workload Identity.
50+
CredentialTypeWorkloadIdentity
51+
)
52+
53+
type credentialCacheKey struct {
54+
authorityHost string
55+
credentialType CredentialType
56+
tenantID string
57+
clientID string
58+
}
59+
60+
// NewCredentialCache creates a new, empty CredentialCache.
61+
func NewCredentialCache() CredentialCache {
62+
return &credentialCache{
63+
mut: new(sync.Mutex),
64+
cache: make(map[credentialCacheKey]azcore.TokenCredential),
65+
credFactory: azureCredentialFactory{},
66+
}
67+
}
68+
69+
func (c *credentialCache) GetOrStoreClientSecret(tenantID, clientID, clientSecret string, opts *azidentity.ClientSecretCredentialOptions) (azcore.TokenCredential, error) {
70+
return c.getOrStore(
71+
credentialCacheKey{
72+
authorityHost: opts.Cloud.ActiveDirectoryAuthorityHost,
73+
credentialType: CredentialTypeClientSecret,
74+
tenantID: tenantID,
75+
clientID: clientID,
76+
},
77+
func() (azcore.TokenCredential, error) {
78+
return c.credFactory.newClientSecretCredential(tenantID, clientID, clientSecret, opts)
79+
},
80+
)
81+
}
82+
83+
func (c *credentialCache) GetOrStoreClientCert(tenantID, clientID string, cert, certPassword []byte, opts *azidentity.ClientCertificateCredentialOptions) (azcore.TokenCredential, error) {
84+
return c.getOrStore(
85+
credentialCacheKey{
86+
authorityHost: opts.Cloud.ActiveDirectoryAuthorityHost,
87+
credentialType: CredentialTypeClientCert,
88+
tenantID: tenantID,
89+
clientID: clientID,
90+
},
91+
func() (azcore.TokenCredential, error) {
92+
return c.credFactory.newClientCertificateCredential(tenantID, clientID, cert, certPassword, opts)
93+
},
94+
)
95+
}
96+
97+
func (c *credentialCache) GetOrStoreManagedIdentity(opts *azidentity.ManagedIdentityCredentialOptions) (azcore.TokenCredential, error) {
98+
return c.getOrStore(
99+
credentialCacheKey{
100+
authorityHost: opts.Cloud.ActiveDirectoryAuthorityHost,
101+
credentialType: CredentialTypeManagedIdentity,
102+
// tenantID not used for managed identity
103+
clientID: opts.ID.String(),
104+
},
105+
func() (azcore.TokenCredential, error) {
106+
return c.credFactory.newManagedIdentityCredential(opts)
107+
},
108+
)
109+
}
110+
111+
func (c *credentialCache) GetOrStoreWorkloadIdentity(opts *azidentity.WorkloadIdentityCredentialOptions) (azcore.TokenCredential, error) {
112+
return c.getOrStore(
113+
credentialCacheKey{
114+
authorityHost: opts.Cloud.ActiveDirectoryAuthorityHost,
115+
credentialType: CredentialTypeWorkloadIdentity,
116+
tenantID: opts.TenantID,
117+
clientID: opts.ClientID,
118+
},
119+
func() (azcore.TokenCredential, error) {
120+
return c.credFactory.newWorkloadIdentityCredential(opts)
121+
},
122+
)
123+
}
124+
125+
func (c *credentialCache) getOrStore(key credentialCacheKey, newCredFunc func() (azcore.TokenCredential, error)) (azcore.TokenCredential, error) {
126+
c.mut.Lock()
127+
defer c.mut.Unlock()
128+
if cred, exists := c.cache[key]; exists {
129+
return cred, nil
130+
}
131+
cred, err := newCredFunc()
132+
if err != nil {
133+
return nil, err
134+
}
135+
c.cache[key] = cred
136+
return cred, nil
137+
}
138+
139+
type azureCredentialFactory struct{}
140+
141+
func (azureCredentialFactory) newClientSecretCredential(tenantID string, clientID string, clientSecret string, opts *azidentity.ClientSecretCredentialOptions) (azcore.TokenCredential, error) {
142+
return azidentity.NewClientSecretCredential(tenantID, clientID, clientSecret, opts)
143+
}
144+
145+
func (azureCredentialFactory) newClientCertificateCredential(tenantID string, clientID string, clientCertificate []byte, clientCertificatePassword []byte, opts *azidentity.ClientCertificateCredentialOptions) (azcore.TokenCredential, error) {
146+
certs, certKey, err := azidentity.ParseCertificates(clientCertificate, clientCertificatePassword)
147+
if err != nil {
148+
return nil, err
149+
}
150+
return azidentity.NewClientCertificateCredential(tenantID, clientID, certs, certKey, opts)
151+
}
152+
153+
func (azureCredentialFactory) newManagedIdentityCredential(opts *azidentity.ManagedIdentityCredentialOptions) (azcore.TokenCredential, error) {
154+
return azidentity.NewManagedIdentityCredential(opts)
155+
}
156+
157+
func (azureCredentialFactory) newWorkloadIdentityCredential(opts *azidentity.WorkloadIdentityCredentialOptions) (azcore.TokenCredential, error) {
158+
return azidentity.NewWorkloadIdentityCredential(opts)
159+
}

azure/credential_cache_test.go

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
/*
2+
Copyright 2024 The Kubernetes Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package azure
18+
19+
import (
20+
"context"
21+
"strconv"
22+
"sync"
23+
"testing"
24+
25+
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
26+
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
27+
. "github.com/onsi/gomega"
28+
"github.com/pkg/errors"
29+
)
30+
31+
type fakeTokenCredential struct {
32+
tenantID string
33+
}
34+
35+
func (t fakeTokenCredential) GetToken(ctx context.Context, options policy.TokenRequestOptions) (azcore.AccessToken, error) {
36+
return azcore.AccessToken{}, nil
37+
}
38+
39+
func TestGetOrStore(t *testing.T) {
40+
g := NewGomegaWithT(t)
41+
42+
credCache := &credentialCache{
43+
mut: new(sync.Mutex),
44+
cache: make(map[credentialCacheKey]azcore.TokenCredential),
45+
}
46+
47+
newCredCount := 0
48+
newCredFunc := func(cred fakeTokenCredential, err error) func() (azcore.TokenCredential, error) {
49+
return func() (azcore.TokenCredential, error) {
50+
newCredCount++
51+
return cred, err
52+
}
53+
}
54+
55+
// the first call for a new key should invoke newCredFunc
56+
cred, err := credCache.getOrStore(credentialCacheKey{tenantID: "1"}, newCredFunc(fakeTokenCredential{tenantID: "1"}, nil))
57+
g.Expect(err).NotTo(HaveOccurred())
58+
g.Expect(cred).To(Equal(fakeTokenCredential{tenantID: "1"}))
59+
g.Expect(newCredCount).To(Equal(1))
60+
61+
// subsequent calls for the same key should not create a new credential
62+
cred, err = credCache.getOrStore(credentialCacheKey{tenantID: "1"}, newCredFunc(fakeTokenCredential{tenantID: "1"}, nil))
63+
g.Expect(err).NotTo(HaveOccurred())
64+
g.Expect(cred).To(Equal(fakeTokenCredential{tenantID: "1"}))
65+
g.Expect(newCredCount).To(Equal(1))
66+
cred, err = credCache.getOrStore(credentialCacheKey{tenantID: "1"}, newCredFunc(fakeTokenCredential{tenantID: "1"}, nil))
67+
g.Expect(err).NotTo(HaveOccurred())
68+
g.Expect(cred).To(Equal(fakeTokenCredential{tenantID: "1"}))
69+
g.Expect(newCredCount).To(Equal(1))
70+
71+
expectedErr := errors.New("an error")
72+
cred, err = credCache.getOrStore(credentialCacheKey{tenantID: "2"}, newCredFunc(fakeTokenCredential{tenantID: "2"}, expectedErr))
73+
g.Expect(err).To(MatchError(expectedErr))
74+
g.Expect(cred).To(BeNil())
75+
g.Expect(newCredCount).To(Equal(2))
76+
}
77+
78+
func TestGetOrStoreRace(t *testing.T) {
79+
// This test makes no assertions, it only fails when the race detector finds race conditions.
80+
81+
credCache := &credentialCache{
82+
mut: new(sync.Mutex),
83+
cache: make(map[credentialCacheKey]azcore.TokenCredential),
84+
}
85+
newCredFunc := func(cred fakeTokenCredential, err error) func() (azcore.TokenCredential, error) {
86+
return func() (azcore.TokenCredential, error) {
87+
return cred, err
88+
}
89+
}
90+
91+
wg := new(sync.WaitGroup)
92+
n := 1000
93+
for i := 0; i < n; i++ {
94+
wg.Add(1)
95+
go func() {
96+
defer wg.Done()
97+
_, _ = credCache.getOrStore(credentialCacheKey{tenantID: strconv.Itoa(i % 100)}, newCredFunc(fakeTokenCredential{}, nil))
98+
}()
99+
}
100+
wg.Wait()
101+
}

azure/interfaces.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121
"time"
2222

2323
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
24+
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
2425
"github.com/Azure/azure-service-operator/v2/pkg/genruntime"
2526
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
2627
clusterv1 "sigs.k8s.io/cluster-api/api/v1beta1"
@@ -164,3 +165,11 @@ type ASOResourceSpecGetter[T genruntime.MetaObject] interface {
164165
// non-ASO-backed CAPZ and should be considered eligible for adoption.
165166
WasManaged(T) bool
166167
}
168+
169+
// CredentialCache caches azcore.TokenCredentials.
170+
type CredentialCache interface {
171+
GetOrStoreClientSecret(tenantID, clientID, clientSecret string, opts *azidentity.ClientSecretCredentialOptions) (azcore.TokenCredential, error)
172+
GetOrStoreClientCert(tenantID, clientID string, cert, certPassword []byte, opts *azidentity.ClientCertificateCredentialOptions) (azcore.TokenCredential, error)
173+
GetOrStoreManagedIdentity(opts *azidentity.ManagedIdentityCredentialOptions) (azcore.TokenCredential, error)
174+
GetOrStoreWorkloadIdentity(opts *azidentity.WorkloadIdentityCredentialOptions) (azcore.TokenCredential, error)
175+
}

azure/mock_azure/azure_mock.go

Lines changed: 84 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)