Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for AssumeRoleWithWebIdentity #178

Merged
merged 9 commits into from
Apr 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
203 changes: 170 additions & 33 deletions aws_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ func TestGetAwsConfig(t *testing.T) {
Description string
EnableEc2MetadataServer bool
EnableEcsCredentialsServer bool
EnableWebIdentityToken bool
EnableWebIdentityEnvVars bool
EnableWebIdentityConfig bool
EnvironmentVariables map[string]string
ExpectedCredentialsValue aws.Credentials
ExpectedRegion string
Expand Down Expand Up @@ -99,7 +100,7 @@ func TestGetAwsConfig(t *testing.T) {
Region: "us-east-1",
SecretKey: servicemocks.MockStaticSecretKey,
},
Description: "config AssumeRoleDurationSeconds",
Description: "config AssumeRoleDuration",
ExpectedCredentialsValue: mockdata.MockStsAssumeRoleCredentials,
ExpectedRegion: "us-east-1",
MockStsEndpoints: []*servicemocks.MockEndpoint{
Expand Down Expand Up @@ -494,7 +495,7 @@ aws_secret_access_key = DefaultSharedCredentialsSecretKey
Region: "us-east-1",
},
Description: "web identity token access key",
EnableWebIdentityToken: true,
EnableWebIdentityEnvVars: true,
ExpectedCredentialsValue: mockdata.MockStsAssumeRoleWithWebIdentityCredentials,
ExpectedRegion: "us-east-1",
MockStsEndpoints: []*servicemocks.MockEndpoint{
Expand Down Expand Up @@ -560,6 +561,42 @@ aws_secret_access_key = DefaultSharedCredentialsSecretKey
servicemocks.MockStsGetCallerIdentityValidEndpoint,
},
},
{
Config: &Config{
AssumeRole: &AssumeRole{
RoleARN: servicemocks.MockStsAssumeRoleArn,
SessionName: servicemocks.MockStsAssumeRoleSessionName,
},
Region: "us-east-1",
},
Description: "AssumeWebIdentity envvar AssumeRoleARN access key",
EnableWebIdentityEnvVars: true,
ExpectedCredentialsValue: mockdata.MockStsAssumeRoleCredentials,
ExpectedRegion: "us-east-1",
MockStsEndpoints: []*servicemocks.MockEndpoint{
servicemocks.MockStsAssumeRoleWithWebIdentityValidEndpoint,
servicemocks.MockStsAssumeRoleValidEndpoint,
servicemocks.MockStsGetCallerIdentityValidEndpoint,
},
},
{
Config: &Config{
AssumeRole: &AssumeRole{
RoleARN: servicemocks.MockStsAssumeRoleArn,
SessionName: servicemocks.MockStsAssumeRoleSessionName,
},
Region: "us-east-1",
},
Description: "AssumeWebIdentity config AssumeRoleARN access key",
EnableWebIdentityConfig: true,
ExpectedCredentialsValue: mockdata.MockStsAssumeRoleCredentials,
ExpectedRegion: "us-east-1",
MockStsEndpoints: []*servicemocks.MockEndpoint{
servicemocks.MockStsAssumeRoleWithWebIdentityValidEndpoint,
servicemocks.MockStsAssumeRoleValidEndpoint,
servicemocks.MockStsGetCallerIdentityValidEndpoint,
},
},
{
Config: &Config{
AccessKey: servicemocks.MockStaticAccessKey,
Expand Down Expand Up @@ -912,9 +949,8 @@ aws_secret_access_key = DefaultSharedCredentialsSecretKey
defer closeEcsCredentials()
}

if testCase.EnableWebIdentityToken {
if testCase.EnableWebIdentityEnvVars || testCase.EnableWebIdentityConfig {
file, err := ioutil.TempFile("", "aws-sdk-go-base-web-identity-token-file")

if err != nil {
t.Fatalf("unexpected error creating temporary web identity token file: %s", err)
}
Expand All @@ -927,9 +963,17 @@ aws_secret_access_key = DefaultSharedCredentialsSecretKey
t.Fatalf("unexpected error writing web identity token file: %s", err)
}

os.Setenv("AWS_ROLE_ARN", servicemocks.MockStsAssumeRoleWithWebIdentityArn)
os.Setenv("AWS_ROLE_SESSION_NAME", servicemocks.MockStsAssumeRoleWithWebIdentitySessionName)
os.Setenv("AWS_WEB_IDENTITY_TOKEN_FILE", file.Name())
if testCase.EnableWebIdentityEnvVars {
os.Setenv("AWS_ROLE_ARN", servicemocks.MockStsAssumeRoleWithWebIdentityArn)
os.Setenv("AWS_ROLE_SESSION_NAME", servicemocks.MockStsAssumeRoleWithWebIdentitySessionName)
os.Setenv("AWS_WEB_IDENTITY_TOKEN_FILE", file.Name())
} else if testCase.EnableWebIdentityConfig {
testCase.Config.AssumeRoleWithWebIdentity = &AssumeRoleWithWebIdentity{
RoleARN: servicemocks.MockStsAssumeRoleWithWebIdentityArn,
SessionName: servicemocks.MockStsAssumeRoleWithWebIdentitySessionName,
WebIdentityTokenFile: file.Name(),
}
}
}

closeSts, _, stsEndpoint := mockdata.GetMockedAwsApiSession("STS", testCase.MockStsEndpoints)
Expand Down Expand Up @@ -2288,21 +2332,56 @@ func TestAssumeRoleWithWebIdentity(t *testing.T) {
testCases := map[string]struct {
Config *Config
SetConfig bool
ExpandEnvVars bool
EnvironmentVariables map[string]string
SetEnvironmentVariable bool
SharedConfigurationFile string
SetSharedConfigurationFile bool
ExpectedCredentialsValue aws.Credentials
MockStsEndpoints []*servicemocks.MockEndpoint
}{
// "config": {
// Config: &Config{},
// SetConfig: true,
// ExpectedCredentialsValue: mockdata.MockStsAssumeRoleWithWebIdentityCredentials,
// MockStsEndpoints: []*servicemocks.MockEndpoint{
// servicemocks.MockStsAssumeRoleWithWebIdentityValidEndpoint,
// },
// },
"config with inline token": {
Config: &Config{
AssumeRoleWithWebIdentity: &AssumeRoleWithWebIdentity{
RoleARN: servicemocks.MockStsAssumeRoleWithWebIdentityArn,
SessionName: servicemocks.MockStsAssumeRoleWithWebIdentitySessionName,
WebIdentityToken: servicemocks.MockWebIdentityToken,
},
},
ExpectedCredentialsValue: mockdata.MockStsAssumeRoleWithWebIdentityCredentials,
MockStsEndpoints: []*servicemocks.MockEndpoint{
servicemocks.MockStsAssumeRoleWithWebIdentityValidEndpoint,
},
},

"config with token file": {
Config: &Config{
AssumeRoleWithWebIdentity: &AssumeRoleWithWebIdentity{
RoleARN: servicemocks.MockStsAssumeRoleWithWebIdentityArn,
SessionName: servicemocks.MockStsAssumeRoleWithWebIdentitySessionName,
},
},
SetConfig: true,
ExpectedCredentialsValue: mockdata.MockStsAssumeRoleWithWebIdentityCredentials,
MockStsEndpoints: []*servicemocks.MockEndpoint{
servicemocks.MockStsAssumeRoleWithWebIdentityValidEndpoint,
},
},

"config with expanded path": {
Config: &Config{
AssumeRoleWithWebIdentity: &AssumeRoleWithWebIdentity{
RoleARN: servicemocks.MockStsAssumeRoleWithWebIdentityArn,
SessionName: servicemocks.MockStsAssumeRoleWithWebIdentitySessionName,
},
},
SetConfig: true,
ExpandEnvVars: true,
ExpectedCredentialsValue: mockdata.MockStsAssumeRoleWithWebIdentityCredentials,
MockStsEndpoints: []*servicemocks.MockEndpoint{
servicemocks.MockStsAssumeRoleWithWebIdentityValidEndpoint,
},
},

"envvar": {
Config: &Config{},
Expand Down Expand Up @@ -2331,19 +2410,24 @@ role_session_name = %[2]s
},
},

// "config overrides envvar": {
// Config: &Config{},
// SetConfig: true,
// EnvironmentVariables: map[string]string{
// "AWS_ROLE_ARN": servicemocks.MockStsAssumeRoleWithWebIdentityArn,
// "AWS_ROLE_SESSION_NAME": servicemocks.MockStsAssumeRoleWithWebIdentitySessionName,
// "AWS_WEB_IDENTITY_TOKEN_FILE": "no-such-file",
// },
// ExpectedCredentialsValue: mockdata.MockStsAssumeRoleWithWebIdentityCredentials,
// MockStsEndpoints: []*servicemocks.MockEndpoint{
// servicemocks.MockStsAssumeRoleWithWebIdentityValidEndpoint,
// },
// },
"config overrides envvar": {
Config: &Config{
AssumeRoleWithWebIdentity: &AssumeRoleWithWebIdentity{
RoleARN: servicemocks.MockStsAssumeRoleWithWebIdentityArn,
SessionName: servicemocks.MockStsAssumeRoleWithWebIdentitySessionName,
WebIdentityToken: servicemocks.MockWebIdentityToken,
},
},
EnvironmentVariables: map[string]string{
"AWS_ROLE_ARN": servicemocks.MockStsAssumeRoleWithWebIdentityArn,
"AWS_ROLE_SESSION_NAME": servicemocks.MockStsAssumeRoleWithWebIdentitySessionName,
"AWS_WEB_IDENTITY_TOKEN_FILE": "no-such-file",
},
ExpectedCredentialsValue: mockdata.MockStsAssumeRoleWithWebIdentityCredentials,
MockStsEndpoints: []*servicemocks.MockEndpoint{
servicemocks.MockStsAssumeRoleWithWebIdentityValidEndpoint,
},
},

"envvar overrides shared configuration": {
Config: &Config{},
Expand All @@ -2363,6 +2447,36 @@ web_identity_token_file = no-such-file
servicemocks.MockStsAssumeRoleWithWebIdentityValidEndpoint,
},
},

"with duration": {
Config: &Config{
AssumeRoleWithWebIdentity: &AssumeRoleWithWebIdentity{
RoleARN: servicemocks.MockStsAssumeRoleWithWebIdentityArn,
SessionName: servicemocks.MockStsAssumeRoleWithWebIdentitySessionName,
WebIdentityToken: servicemocks.MockWebIdentityToken,
Duration: 1 * time.Hour,
},
},
ExpectedCredentialsValue: mockdata.MockStsAssumeRoleWithWebIdentityCredentials,
MockStsEndpoints: []*servicemocks.MockEndpoint{
servicemocks.MockStsAssumeRoleWithWebIdentityValidWithOptions(map[string]string{"DurationSeconds": "3600"}),
},
},

"with policy": {
Config: &Config{
AssumeRoleWithWebIdentity: &AssumeRoleWithWebIdentity{
RoleARN: servicemocks.MockStsAssumeRoleWithWebIdentityArn,
SessionName: servicemocks.MockStsAssumeRoleWithWebIdentitySessionName,
WebIdentityToken: servicemocks.MockWebIdentityToken,
Policy: "{}",
},
},
ExpectedCredentialsValue: mockdata.MockStsAssumeRoleWithWebIdentityCredentials,
MockStsEndpoints: []*servicemocks.MockEndpoint{
servicemocks.MockStsAssumeRoleWithWebIdentityValidWithOptions(map[string]string{"Policy": "{}"}),
},
},
}

for testName, testCase := range testCases {
Expand All @@ -2381,21 +2495,44 @@ web_identity_token_file = no-such-file

testCase.Config.StsEndpoint = stsEndpoint

tempdir, err := ioutil.TempDir("", "temp")
if err != nil {
t.Fatalf("error creating temp dir: %s", err)
}
defer os.Remove(tempdir)
os.Setenv("TMPDIR", tempdir)

tokenFile, err := ioutil.TempFile("", "aws-sdk-go-base-web-identity-token-file")
if err != nil {
t.Fatalf("unexpected error creating temporary web identity token file: %s", err)
}
tokenFileName := tokenFile.Name()

defer os.Remove(tokenFile.Name())
defer os.Remove(tokenFileName)

err = ioutil.WriteFile(tokenFile.Name(), []byte(servicemocks.MockWebIdentityToken), 0600)
err = ioutil.WriteFile(tokenFileName, []byte(servicemocks.MockWebIdentityToken), 0600)

if err != nil {
t.Fatalf("unexpected error writing web identity token file: %s", err)
}

if testCase.ExpandEnvVars {
tmpdir := os.Getenv("TMPDIR")
rel, err := filepath.Rel(tmpdir, tokenFileName)
if err != nil {
t.Fatalf("error making path relative: %s", err)
}
t.Logf("relative: %s", rel)
tokenFileName = filepath.Join("$TMPDIR", rel)
t.Logf("env tempfile: %s", tokenFileName)
}

if testCase.SetConfig {
testCase.Config.AssumeRoleWithWebIdentity.WebIdentityTokenFile = tokenFileName
}

if testCase.SetEnvironmentVariable {
os.Setenv("AWS_WEB_IDENTITY_TOKEN_FILE", tokenFile.Name())
os.Setenv("AWS_WEB_IDENTITY_TOKEN_FILE", tokenFileName)
}

if testCase.SharedConfigurationFile != "" {
Expand All @@ -2408,7 +2545,7 @@ web_identity_token_file = no-such-file
defer os.Remove(file.Name())

if testCase.SetSharedConfigurationFile {
testCase.SharedConfigurationFile += fmt.Sprintf("web_identity_token_file = %s\n", tokenFile.Name())
testCase.SharedConfigurationFile += fmt.Sprintf("web_identity_token_file = %s\n", tokenFileName)
}

err = ioutil.WriteFile(file.Name(), []byte(testCase.SharedConfigurationFile), 0600)
Expand Down
2 changes: 2 additions & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ type APNInfo = config.APNInfo

type AssumeRole = config.AssumeRole

type AssumeRoleWithWebIdentity = config.AssumeRoleWithWebIdentity

type UserAgentProducts = config.UserAgentProducts

type UserAgentProduct = config.UserAgentProduct
Expand Down
60 changes: 50 additions & 10 deletions credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,19 @@ func getCredentialsProvider(ctx context.Context, c *Config) (aws.CredentialsProv
return nil, "", fmt.Errorf("loading configuration: %w", err)
}

// This can probably be configured directly in commonLoadOptions() once
// https://github.com/aws/aws-sdk-go-v2/pull/1682 is merged
if c.AssumeRoleWithWebIdentity != nil {
if c.AssumeRoleWithWebIdentity.WebIdentityToken == "" && c.AssumeRoleWithWebIdentity.WebIdentityTokenFile == "" {
return nil, "", c.NewCannotAssumeRoleWithWebIdentityError(fmt.Errorf("one of: WebIdentityToken, WebIdentityTokenFile must be set"))
}
provider, err := webIdentityCredentialsProvider(ctx, cfg, c)
if err != nil {
return nil, "", err
}
cfg.Credentials = provider
}

creds, err := cfg.Credentials.Retrieve(ctx)
if err != nil {
if c.Profile != "" && os.Getenv("AWS_ACCESS_KEY_ID") != "" && os.Getenv("AWS_SECRET_ACCESS_KEY") != "" {
Expand All @@ -153,6 +166,30 @@ Error: %w`, err)
return provider, creds.Source, err
}

func webIdentityCredentialsProvider(ctx context.Context, awsConfig aws.Config, c *Config) (aws.CredentialsProvider, error) {
ar := c.AssumeRoleWithWebIdentity
client := stsClient(awsConfig, c)

appCreds := stscreds.NewWebIdentityRoleProvider(client, ar.RoleARN, ar, func(opts *stscreds.WebIdentityRoleOptions) {
opts.RoleSessionName = ar.SessionName
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI stscreds.WebIdentityRoleProvider doesn't currently support the Policy field from sts.AssumeRoleWithWebIdentityInput. I've created aws/aws-sdk-go-v2#1662 and may create a PR in the SDK

opts.Duration = ar.Duration

if ar.Policy != "" {
opts.Policy = aws.String(ar.Policy)
}

if len(ar.PolicyARNs) > 0 {
opts.PolicyARNs = getPolicyDescriptorTypes(ar.PolicyARNs)
}
})

_, err := appCreds.Retrieve(ctx)
if err != nil {
return nil, c.NewCannotAssumeRoleWithWebIdentityError(err)
}
return aws.NewCredentialsCache(appCreds), nil
}

func assumeRoleCredentialsProvider(ctx context.Context, awsConfig aws.Config, c *Config) (aws.CredentialsProvider, error) {
ar := c.AssumeRole
// When assuming a role, we need to first authenticate the base credentials above, then assume the desired role
Expand All @@ -173,16 +210,7 @@ func assumeRoleCredentialsProvider(ctx context.Context, awsConfig aws.Config, c
}

if len(ar.PolicyARNs) > 0 {
var policyDescriptorTypes []types.PolicyDescriptorType

for _, policyARN := range ar.PolicyARNs {
policyDescriptorType := types.PolicyDescriptorType{
Arn: aws.String(policyARN),
}
policyDescriptorTypes = append(policyDescriptorTypes, policyDescriptorType)
}

opts.PolicyARNs = policyDescriptorTypes
opts.PolicyARNs = getPolicyDescriptorTypes(ar.PolicyARNs)
}

if len(ar.Tags) > 0 {
Expand All @@ -208,3 +236,15 @@ func assumeRoleCredentialsProvider(ctx context.Context, awsConfig aws.Config, c
}
return aws.NewCredentialsCache(appCreds), nil
}

func getPolicyDescriptorTypes(policyARNs []string) []types.PolicyDescriptorType {
var policyDescriptorTypes []types.PolicyDescriptorType

for _, policyARN := range policyARNs {
policyDescriptorType := types.PolicyDescriptorType{
Arn: aws.String(policyARN),
}
policyDescriptorTypes = append(policyDescriptorTypes, policyDescriptorType)
}
return policyDescriptorTypes
}
Loading