Skip to content

Commit

Permalink
Synchronize calls to DiscoverPollEndpoint (#4504)
Browse files Browse the repository at this point in the history
* test: reproduce DiscoverPollEndpoint race condition

* add: synchronize calls to DiscoverPollEndpoint

Agents that share ECSClients can call DiscoverPollEndpoint (DPE) multiple
times per task. Each routine that calls DPE will first check the cache
before performing the actual API call over the network. The intention
here is that only one actual API call is performed (by the first
routine to call DPE).

However, it is possible for multiple routines to race and effectively
make many actual API calls. This is because the `pollEndpointCache` is
only updated when the first API call _returns_.

This change enforces the intended behavior by making subsequent
routines wait for the cache to be updated (or not) by the first
routine, eliminating simultaneous calls to DPE.

---------

Co-authored-by: Isaac Feldman <icf@amazon.com>
  • Loading branch information
isaac-400 and Isaac Feldman authored Feb 27, 2025
1 parent b90cd31 commit deb74a5
Show file tree
Hide file tree
Showing 7 changed files with 108 additions and 17 deletions.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 10 additions & 1 deletion ecs-agent/api/ecs/client/ecs_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"fmt"
"net/http"
"strings"
"sync"
"time"

"github.com/aws/aws-sdk-go/aws"
Expand Down Expand Up @@ -63,6 +64,9 @@ const (
setInstanceIdRetryBackoffMax = 5 * time.Second
setInstanceIdRetryBackoffJitter = 0.2
setInstanceIdRetryBackoffMultiple = 2
// discoverPollEndpointTimeout is the maximum permitted time a single ECSClient.DiscoverPollEndpoint call can take.
// The SDK client uses the default retryer, which gives a max retry count of 3; we combine this with the underlying HTTP client's RoundtripTimeout.
discoverPollEndpointTimeout = 3 * RoundtripTimeout
// Below constants are used for RegisterContainerInstance retry with exponential backoff when receiving non-terminal errors.
// To ensure parity in all regions and on all launch types, we should not set any time limit on the RCI timeout.
// Thus, setting the max RCI retry timeout allowed to 1 hour, and capping max retry backoff at 192 seconds (3 * 2^6).
Expand All @@ -82,6 +86,7 @@ type ecsClient struct {
ec2metadata ec2.EC2MetadataClient
httpClient *http.Client
pollEndpointCache async.TTLCache
pollEndpointLock sync.Mutex
isFIPSDetected bool
shouldExcludeIPv6PortBinding bool
sascCustomRetryBackoff func(func() error) error
Expand Down Expand Up @@ -783,6 +788,10 @@ func (client *ecsClient) DiscoverSystemLogsEndpoint(containerInstanceArn string,

func (client *ecsClient) discoverPollEndpoint(containerInstanceArn string,
availabilityZone string) (*ecsmodel.DiscoverPollEndpointOutput, error) {
client.pollEndpointLock.Lock()
defer client.pollEndpointLock.Unlock()
ctx, cancel := context.WithTimeout(context.Background(), discoverPollEndpointTimeout)
defer cancel()
// Try getting an entry from the cache.
cachedEndpoint, expired, found := client.pollEndpointCache.Get(containerInstanceArn)
if !expired && found {
Expand Down Expand Up @@ -810,7 +819,7 @@ func (client *ecsClient) discoverPollEndpoint(containerInstanceArn string,
field.ContainerInstanceARN: containerInstanceArn,
field.AvailabilityZone: availabilityZone,
})
output, err := client.standardClient.DiscoverPollEndpoint(&ecsmodel.DiscoverPollEndpointInput{
output, err := client.standardClient.DiscoverPollEndpointWithContext(ctx, &ecsmodel.DiscoverPollEndpointInput{
ContainerInstance: &containerInstanceArn,
Cluster: aws.String(client.configAccessor.Cluster()),
ZoneId: aws.String(availabilityZone),
Expand Down
61 changes: 46 additions & 15 deletions ecs-agent/api/ecs/client/ecs_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"errors"
"fmt"
"strings"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -838,7 +839,7 @@ func TestDiscoverTelemetryEndpoint(t *testing.T) {

tester := setup(t, ctrl, ec2.NewBlackholeEC2MetadataClient(), nil)
expectedEndpoint := "http://127.0.0.1"
tester.mockStandardClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).
tester.mockStandardClient.EXPECT().DiscoverPollEndpointWithContext(gomock.Any(), gomock.Any(), gomock.Any()).
Return(&ecsmodel.DiscoverPollEndpointOutput{TelemetryEndpoint: &expectedEndpoint}, nil)
endpoint, err := tester.client.DiscoverTelemetryEndpoint(containerInstanceARN)
assert.NoError(t, err, "Error getting telemetry endpoint")
Expand All @@ -851,7 +852,7 @@ func TestDiscoverTelemetryEndpointError(t *testing.T) {

tester := setup(t, ctrl, ec2.NewBlackholeEC2MetadataClient(), nil)

tester.mockStandardClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(nil,
tester.mockStandardClient.EXPECT().DiscoverPollEndpointWithContext(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil,
fmt.Errorf("Error getting endpoint"))
_, err := tester.client.DiscoverTelemetryEndpoint(containerInstanceARN)
assert.ErrorContains(t, err, "Error getting endpoint",
Expand All @@ -864,7 +865,7 @@ func TestDiscoverNilTelemetryEndpoint(t *testing.T) {

tester := setup(t, ctrl, ec2.NewBlackholeEC2MetadataClient(), nil)
pollEndpoint := "http://127.0.0.1"
tester.mockStandardClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).
tester.mockStandardClient.EXPECT().DiscoverPollEndpointWithContext(gomock.Any(), gomock.Any(), gomock.Any()).
Return(&ecsmodel.DiscoverPollEndpointOutput{Endpoint: &pollEndpoint}, nil)
_, err := tester.client.DiscoverTelemetryEndpoint(containerInstanceARN)
assert.ErrorContains(t, err, "no telemetry endpoint returned",
Expand All @@ -877,7 +878,7 @@ func TestDiscoverServiceConnectEndpoint(t *testing.T) {

tester := setup(t, ctrl, ec2.NewBlackholeEC2MetadataClient(), nil)
expectedEndpoint := "http://127.0.0.1"
tester.mockStandardClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).
tester.mockStandardClient.EXPECT().DiscoverPollEndpointWithContext(gomock.Any(), gomock.Any(), gomock.Any()).
Return(&ecsmodel.DiscoverPollEndpointOutput{ServiceConnectEndpoint: &expectedEndpoint}, nil)
endpoint, err := tester.client.DiscoverServiceConnectEndpoint(containerInstanceARN)
assert.NoError(t, err, "Error getting service connect endpoint")
Expand All @@ -889,7 +890,7 @@ func TestDiscoverServiceConnectEndpointError(t *testing.T) {
defer ctrl.Finish()

tester := setup(t, ctrl, ec2.NewBlackholeEC2MetadataClient(), nil)
tester.mockStandardClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(nil,
tester.mockStandardClient.EXPECT().DiscoverPollEndpointWithContext(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil,
fmt.Errorf("Error getting endpoint"))
_, err := tester.client.DiscoverServiceConnectEndpoint(containerInstanceARN)
assert.ErrorContains(t, err, "Error getting endpoint",
Expand All @@ -902,7 +903,7 @@ func TestDiscoverNilServiceConnectEndpoint(t *testing.T) {

tester := setup(t, ctrl, ec2.NewBlackholeEC2MetadataClient(), nil)
pollEndpoint := "http://127.0.0.1"
tester.mockStandardClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).
tester.mockStandardClient.EXPECT().DiscoverPollEndpointWithContext(gomock.Any(), gomock.Any(), gomock.Any()).
Return(&ecsmodel.DiscoverPollEndpointOutput{Endpoint: &pollEndpoint}, nil)
_, err := tester.client.DiscoverServiceConnectEndpoint(containerInstanceARN)
assert.ErrorContains(t, err, "no ServiceConnect endpoint returned",
Expand All @@ -915,7 +916,7 @@ func TestDiscoverSystemLogsEndpoint(t *testing.T) {

tester := setup(t, ctrl, ec2.NewBlackholeEC2MetadataClient(), nil)
expectedEndpoint := "http://127.0.0.1"
tester.mockStandardClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).
tester.mockStandardClient.EXPECT().DiscoverPollEndpointWithContext(gomock.Any(), gomock.Any(), gomock.Any()).
Return(&ecsmodel.DiscoverPollEndpointOutput{SystemLogsEndpoint: &expectedEndpoint}, nil)
endpoint, err := tester.client.DiscoverSystemLogsEndpoint(containerInstanceARN, zoneId)
assert.NoError(t, err, "Error getting system logs endpoint")
Expand All @@ -927,7 +928,7 @@ func TestDiscoverSystemLogsEndpointError(t *testing.T) {
defer ctrl.Finish()

tester := setup(t, ctrl, ec2.NewBlackholeEC2MetadataClient(), nil)
tester.mockStandardClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(nil,
tester.mockStandardClient.EXPECT().DiscoverPollEndpointWithContext(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil,
fmt.Errorf("Error getting endpoint"))
_, err := tester.client.DiscoverSystemLogsEndpoint(containerInstanceARN, zoneId)
assert.ErrorContains(t, err, "Error getting endpoint",
Expand All @@ -940,13 +941,43 @@ func TestDiscoverNilSystemLogsEndpoint(t *testing.T) {

tester := setup(t, ctrl, ec2.NewBlackholeEC2MetadataClient(), nil)
pollEndpoint := "http://127.0.0.1"
tester.mockStandardClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).
tester.mockStandardClient.EXPECT().DiscoverPollEndpointWithContext(gomock.Any(), gomock.Any(), gomock.Any()).
Return(&ecsmodel.DiscoverPollEndpointOutput{Endpoint: &pollEndpoint}, nil)
_, err := tester.client.DiscoverSystemLogsEndpoint(containerInstanceARN, zoneId)
assert.ErrorContains(t, err, "no system logs endpoint returned",
"Expected error getting system logs endpoint with old response")
}

// TestDiscoverPollEndpointRace checks that two concurrent DiscoverPollEndpoint
// callers sharing one client trigger exactly one SDK call, and that both
// callers still receive the discovered endpoint without error.
func TestDiscoverPollEndpointRace(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	tester := setup(t, ctrl, ec2.NewBlackholeEC2MetadataClient(), nil)
	pollEndpoint := "http://127.0.0.1"
	// SDK call to DiscoverPollEndpoint should only happen once.
	tester.mockStandardClient.EXPECT().
		DiscoverPollEndpointWithContext(gomock.Any(), gomock.Any(), gomock.Any()).
		Times(1).
		Do(func(interface{}, interface{}, ...interface{}) {
			// Wait before returning to try and induce the race condition.
			time.Sleep(100 * time.Millisecond)
		}).
		Return(&ecsmodel.DiscoverPollEndpointOutput{Endpoint: &pollEndpoint}, nil)

	// Each goroutine writes only to its own slot, and the slots are read only
	// after wg.Wait(), so no additional locking is required.
	var (
		wg        sync.WaitGroup
		endpoints [2]string
		errs      [2]error
	)
	wg.Add(2)

	// First caller.
	go func() {
		defer wg.Done()
		endpoints[0], errs[0] = tester.client.DiscoverPollEndpoint(containerInstanceARN)
	}()
	// Second caller.
	go func() {
		defer wg.Done()
		// Start slightly later so the first caller has likely already begun
		// its (slow) SDK call when this one arrives.
		time.Sleep(10 * time.Millisecond)
		endpoints[1], errs[1] = tester.client.DiscoverPollEndpoint(containerInstanceARN)
	}()

	wg.Wait()

	// Previously the results were discarded, so the test only verified the
	// mock call count; also confirm every caller got a usable answer.
	for i := range errs {
		assert.NoError(t, errs[i], "Error in DiscoverPollEndpoint for caller %d", i)
		assert.Equal(t, pollEndpoint, endpoints[i], "Mismatch in poll endpoint for caller %d", i)
	}
}

func TestUpdateContainerInstancesState(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
Expand Down Expand Up @@ -1041,7 +1072,7 @@ func TestDiscoverPollEndpointCacheMiss(t *testing.T) {

gomock.InOrder(
pollEndpointCache.EXPECT().Get(containerInstanceARN).Return(nil, false, false),
tester.mockStandardClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(pollEndpointOutput, nil),
tester.mockStandardClient.EXPECT().DiscoverPollEndpointWithContext(gomock.Any(), gomock.Any(), gomock.Any()).Return(pollEndpointOutput, nil),
pollEndpointCache.EXPECT().Set(containerInstanceARN, pollEndpointOutput),
)

Expand All @@ -1064,7 +1095,7 @@ func TestDiscoverPollEndpointExpiredButDPEFailed(t *testing.T) {

gomock.InOrder(
pollEndpointCache.EXPECT().Get(containerInstanceARN).Return(pollEndpointOutput, true, false),
tester.mockStandardClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(nil, fmt.Errorf("error!")),
tester.mockStandardClient.EXPECT().DiscoverPollEndpointWithContext(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, fmt.Errorf("error!")),
)

output, err := tester.client.(*ecsClient).discoverPollEndpoint(containerInstanceARN, "")
Expand All @@ -1081,7 +1112,7 @@ func TestDiscoverTelemetryEndpointAfterPollEndpointCacheHit(t *testing.T) {
tester := setup(t, ctrl, ec2.NewBlackholeEC2MetadataClient(), nil,
WithDiscoverPollEndpointCache(pollEndpointCache))
pollEndpoint := "http://127.0.0.1"
tester.mockStandardClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(
tester.mockStandardClient.EXPECT().DiscoverPollEndpointWithContext(gomock.Any(), gomock.Any(), gomock.Any()).Return(
&ecsmodel.DiscoverPollEndpointOutput{
Endpoint: &pollEndpoint,
TelemetryEndpoint: &pollEndpoint,
Expand All @@ -1106,7 +1137,7 @@ func TestDiscoverSystemLogsEndpointAfterCacheHit_HappyPath(t *testing.T) {
tester := setup(t, ctrl, ec2.NewBlackholeEC2MetadataClient(), nil,
WithDiscoverPollEndpointCache(pollEndpointCache))
pollEndpoint := "http://127.0.0.1"
tester.mockStandardClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(
tester.mockStandardClient.EXPECT().DiscoverPollEndpointWithContext(gomock.Any(), gomock.Any(), gomock.Any()).Return(
&ecsmodel.DiscoverPollEndpointOutput{
Endpoint: &pollEndpoint,
SystemLogsEndpoint: &pollEndpoint,
Expand All @@ -1130,15 +1161,15 @@ func TestDiscoverSystemLogsEndpointAfterPollEndpointCacheHit_UnhappyPath(t *test
tester := setup(t, ctrl, ec2.NewBlackholeEC2MetadataClient(), nil,
WithDiscoverPollEndpointCache(pollEndpointCache))
pollEndpoint := "http://127.0.0.1"
tester.mockStandardClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(
tester.mockStandardClient.EXPECT().DiscoverPollEndpointWithContext(gomock.Any(), gomock.Any(), gomock.Any()).Return(
&ecsmodel.DiscoverPollEndpointOutput{
Endpoint: &pollEndpoint,
}, nil).Times(1)
endpoint, err := tester.client.DiscoverPollEndpoint(containerInstanceARN)
assert.NoError(t, err, "Error in DiscoverPollEndpoint")
assert.Equal(t, pollEndpoint, endpoint, "Mismatch in poll endpoint")

tester.mockStandardClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(
tester.mockStandardClient.EXPECT().DiscoverPollEndpointWithContext(gomock.Any(), gomock.Any(), gomock.Any()).Return(
&ecsmodel.DiscoverPollEndpointOutput{
Endpoint: &pollEndpoint,
SystemLogsEndpoint: &pollEndpoint,
Expand Down
1 change: 1 addition & 0 deletions ecs-agent/api/ecs/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ type ECSStandardSDK interface {
CreateCluster(*ecs.CreateClusterInput) (*ecs.CreateClusterOutput, error)
RegisterContainerInstance(*ecs.RegisterContainerInstanceInput) (*ecs.RegisterContainerInstanceOutput, error)
DiscoverPollEndpoint(*ecs.DiscoverPollEndpointInput) (*ecs.DiscoverPollEndpointOutput, error)
DiscoverPollEndpointWithContext(ctx aws.Context, input *ecs.DiscoverPollEndpointInput, opts ...request.Option) (*ecs.DiscoverPollEndpointOutput, error)
ListTagsForResource(*ecs.ListTagsForResourceInput) (*ecs.ListTagsForResourceOutput, error)
UpdateContainerInstancesState(input *ecs.UpdateContainerInstancesStateInput) (*ecs.UpdateContainerInstancesStateOutput, error)
}
Expand Down
20 changes: 20 additions & 0 deletions ecs-agent/api/ecs/mocks/api_mocks.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit deb74a5

Please sign in to comment.